lots more work. findnode should work now
This commit is contained in:
parent
5c44ca40c2
commit
e13fe7c2f0
8 changed files with 391 additions and 152 deletions
34
dht/decode_test.go
Normal file
34
dht/decode_test.go
Normal file
File diff suppressed because one or more lines are too long
107
dht/dht.go
107
dht/dht.go
|
@ -14,8 +14,18 @@ import (
|
|||
)
|
||||
|
||||
const network = "udp4"
|
||||
const bucketSize = 20
|
||||
|
||||
const alpha = 3 // this is the constant alpha in the spec
|
||||
const nodeIDLength = 48 // bytes. this is the constant B in the spec
|
||||
const bucketSize = 20 // this is the constant k in the spec
|
||||
|
||||
const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
|
||||
const tRefresh = 3600 * time.Second // after which an otherwise unaccessed bucket must be refreshed
|
||||
const tReplicate = 3600 * time.Second // the interval between Kademlia replication events, when a node is required to publish its entire database
|
||||
const tRepublish = 86400 * time.Second // the time after which the original publisher must republish a key/value pair
|
||||
|
||||
const numBuckets = nodeIDLength * 8
|
||||
const compactNodeInfoLength = nodeIDLength + 6
|
||||
|
||||
// packet represents the information receive from udp.
|
||||
type packet struct {
|
||||
|
@ -67,11 +77,21 @@ func New(config *Config) *DHT {
|
|||
} else {
|
||||
id = newBitmapFromHex(config.NodeID)
|
||||
}
|
||||
node := &Node{id: id, addr: config.Address}
|
||||
|
||||
ip, port, err := net.SplitHostPort(config.Address)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
portInt, err := cast.ToIntE(port)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
node := &Node{id: id, ip: ip, port: portInt}
|
||||
return &DHT{
|
||||
conf: config,
|
||||
node: node,
|
||||
routingTable: NewRoutingTable(node),
|
||||
routingTable: newRoutingTable(node),
|
||||
packets: make(chan packet),
|
||||
store: newPeerStore(),
|
||||
}
|
||||
|
@ -217,33 +237,46 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool)
|
|||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
|
||||
case findNodeMethod:
|
||||
log.Println("findnode")
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
return
|
||||
}
|
||||
nodeID := newBitmapFromString(request.Args[0])
|
||||
closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize)
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
|
||||
for i, n := range closestNodes {
|
||||
response.FindNodeData[i] = *n
|
||||
}
|
||||
send(dht, addr, response)
|
||||
case findValueMethod:
|
||||
log.Println("findvalue")
|
||||
//if len(request.Args) < 1 {
|
||||
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//target := request.Args[0]
|
||||
//if len(target) != nodeIDLength {
|
||||
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid target"}})
|
||||
// return
|
||||
//}
|
||||
//
|
||||
//nodes := []findNodeDatum{}
|
||||
//targetID := newBitmapFromString(target)
|
||||
//
|
||||
//no, _ := dht.routingTable.GetNodeKBucktByID(targetID)
|
||||
//if no != nil {
|
||||
// nodes = []findNodeDatum{{ID: no.id.RawString(), IP: no.addr.IP.String(), Port: no.addr.Port}}
|
||||
//} else {
|
||||
// neighbors := dht.routingTable.GetNeighbors(targetID, dht.K)
|
||||
// for _, n := range neighbors {
|
||||
// nodes = append(nodes, findNodeDatum{ID: n.id.RawString(), IP: n.addr.IP.String(), Port: n.addr.Port})
|
||||
// }
|
||||
//}
|
||||
//
|
||||
//send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: nodes})
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
return
|
||||
}
|
||||
|
||||
nodeIDs := dht.store.Get(request.Args[0])
|
||||
if len(nodeIDs) > 0 {
|
||||
// return node ids
|
||||
} else {
|
||||
// switch to findNode
|
||||
}
|
||||
|
||||
nodeID := newBitmapFromString(request.Args[0])
|
||||
closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize)
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
|
||||
for i, n := range closestNodes {
|
||||
response.FindNodeData[i] = *n
|
||||
}
|
||||
send(dht, addr, response)
|
||||
|
||||
default:
|
||||
// send(dht, addr, makeError(t, protocolError, "invalid q"))
|
||||
|
@ -251,7 +284,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool)
|
|||
return
|
||||
}
|
||||
|
||||
node := &Node{id: newBitmapFromString(request.NodeID), addr: addr.String()}
|
||||
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP.String(), port: addr.Port}
|
||||
dht.routingTable.Update(node)
|
||||
return true
|
||||
}
|
||||
|
@ -271,7 +304,7 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success boo
|
|||
// return
|
||||
//}
|
||||
|
||||
node := &Node{id: newBitmapFromString(response.NodeID), addr: addr.String()}
|
||||
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP.String(), port: addr.Port}
|
||||
dht.routingTable.Update(node)
|
||||
|
||||
return true
|
||||
|
@ -280,6 +313,8 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success boo
|
|||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) {
|
||||
spew.Dump(e)
|
||||
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP.String(), port: addr.Port}
|
||||
dht.routingTable.Update(node)
|
||||
return true
|
||||
}
|
||||
|
||||
|
@ -288,7 +323,7 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
|||
if req, ok := data.(Request); ok {
|
||||
log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args))
|
||||
} else if res, ok := data.(Response); ok {
|
||||
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], res.Data)
|
||||
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], spew.Sdump(res.Data))
|
||||
} else {
|
||||
log.Debugf("[%s] %s", spew.Sdump(data))
|
||||
}
|
||||
|
@ -305,7 +340,7 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
|||
}
|
||||
|
||||
func getArgs(argsInt interface{}) []string {
|
||||
args := []string{}
|
||||
var args []string
|
||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(argsInt)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
|
@ -316,10 +351,12 @@ func getArgs(argsInt interface{}) []string {
|
|||
}
|
||||
|
||||
func argsToString(args []string) string {
|
||||
for k, v := range args {
|
||||
argsCopy := make([]string, len(args))
|
||||
copy(argsCopy, args)
|
||||
for k, v := range argsCopy {
|
||||
if len(v) == nodeIDLength {
|
||||
args[k] = hex.EncodeToString([]byte(v))[:8]
|
||||
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
||||
}
|
||||
}
|
||||
return strings.Join(args, ", ")
|
||||
return strings.Join(argsCopy, ", ")
|
||||
}
|
||||
|
|
187
dht/dht_test.go
187
dht/dht_test.go
|
@ -169,47 +169,9 @@ func TestStore(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
if len(response) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(response))
|
||||
}
|
||||
verifyResponse(t, response, messageID, dhtNodeID.RawString())
|
||||
|
||||
_, ok := response[headerTypeField]
|
||||
if !ok {
|
||||
t.Error("missing type field")
|
||||
} else {
|
||||
rType, ok := response[headerTypeField].(int64)
|
||||
if !ok {
|
||||
t.Error("type is not an integer")
|
||||
} else if rType != responseType {
|
||||
t.Error("unexpected response type")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerMessageIDField]
|
||||
if !ok {
|
||||
t.Error("missing message id field")
|
||||
} else {
|
||||
rMessageID, ok := response[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != messageID {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerNodeIDField]
|
||||
if !ok {
|
||||
t.Error("missing node id field")
|
||||
} else {
|
||||
rNodeID, ok := response[headerNodeIDField].(string)
|
||||
if !ok {
|
||||
t.Error("node ID is not a string")
|
||||
} else if rNodeID != dhtNodeID.RawString() {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = response[headerPayloadField]
|
||||
_, ok := response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Error("missing payload field")
|
||||
} else {
|
||||
|
@ -236,6 +198,7 @@ func TestStore(t *testing.T) {
|
|||
|
||||
func TestFindNode(t *testing.T) {
|
||||
dhtNodeID := newRandomBitmap()
|
||||
testNodeID := newRandomBitmap()
|
||||
|
||||
conn := newTestUDPConn("127.0.0.1:21217")
|
||||
|
||||
|
@ -244,23 +207,58 @@ func TestFindNode(t *testing.T) {
|
|||
dht.listen()
|
||||
go dht.runHandler()
|
||||
|
||||
data, _ := hex.DecodeString("64313a30693065313a3132303a2afdf2272981651a2c64e39ab7f04ec2d3b5d5d2313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33383a66696e644e6f6465313a346c34383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b6565")
|
||||
nodesToInsert := 3
|
||||
var nodes []Node
|
||||
for i := 0; i < nodesToInsert; i++ {
|
||||
n := Node{id: newRandomBitmap(), ip: "127.0.0.1", port: 10000 + i}
|
||||
nodes = append(nodes, n)
|
||||
dht.routingTable.Update(&n)
|
||||
}
|
||||
|
||||
messageID := newRandomBitmap().RawString()
|
||||
blobHashToFind := newRandomBitmap().RawString()
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID.RawString(),
|
||||
Method: findNodeMethod,
|
||||
Args: []string{blobHashToFind},
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
|
||||
timer := time.NewTimer(3 * time.Second)
|
||||
|
||||
var response map[string]interface{}
|
||||
select {
|
||||
case <-timer.C:
|
||||
t.Error("timeout")
|
||||
return
|
||||
case resp := <-conn.writes:
|
||||
var response map[string]interface{}
|
||||
err := bencode.DecodeBytes(resp.data, &response)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
spew.Dump(response)
|
||||
verifyResponse(t, response, messageID, dhtNodeID.RawString())
|
||||
|
||||
_, ok := response[headerPayloadField]
|
||||
if !ok {
|
||||
t.Error("missing payload field")
|
||||
} else {
|
||||
contacts, ok := response[headerPayloadField].([]interface{})
|
||||
if !ok {
|
||||
t.Error("payload is not a list")
|
||||
} else {
|
||||
verifyContacts(t, contacts, nodes)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -293,3 +291,106 @@ func TestFindValue(t *testing.T) {
|
|||
spew.Dump(response)
|
||||
}
|
||||
}
|
||||
|
||||
func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) {
|
||||
if len(resp) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(resp))
|
||||
}
|
||||
|
||||
_, ok := resp[headerTypeField]
|
||||
if !ok {
|
||||
t.Error("missing type field")
|
||||
} else {
|
||||
rType, ok := resp[headerTypeField].(int64)
|
||||
if !ok {
|
||||
t.Error("type is not an integer")
|
||||
} else if rType != responseType {
|
||||
t.Error("unexpected response type")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = resp[headerMessageIDField]
|
||||
if !ok {
|
||||
t.Error("missing message id field")
|
||||
} else {
|
||||
rMessageID, ok := resp[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != messageID {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
}
|
||||
|
||||
_, ok = resp[headerNodeIDField]
|
||||
if !ok {
|
||||
t.Error("missing node id field")
|
||||
} else {
|
||||
rNodeID, ok := resp[headerNodeIDField].(string)
|
||||
if !ok {
|
||||
t.Error("node ID is not a string")
|
||||
} else if rNodeID != dhtNodeID {
|
||||
t.Error("unexpected node ID")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) {
|
||||
if len(contacts) != len(nodes) {
|
||||
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
|
||||
return
|
||||
}
|
||||
|
||||
foundNodes := make(map[string]bool)
|
||||
|
||||
for _, c := range contacts {
|
||||
contact, ok := c.([]interface{})
|
||||
if !ok {
|
||||
t.Error("contact is not a list")
|
||||
return
|
||||
}
|
||||
|
||||
if len(contact) != 3 {
|
||||
t.Error("contact must be 3 items")
|
||||
return
|
||||
}
|
||||
|
||||
var currNode Node
|
||||
currNodeFound := false
|
||||
|
||||
id, ok := contact[0].(string)
|
||||
if !ok {
|
||||
t.Error("contact id is not a string")
|
||||
} else {
|
||||
if _, ok := foundNodes[id]; ok {
|
||||
t.Errorf("contact %s appears multiple times", id)
|
||||
continue
|
||||
}
|
||||
for _, n := range nodes {
|
||||
if n.id.RawString() == id {
|
||||
currNode = n
|
||||
currNodeFound = true
|
||||
foundNodes[id] = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !currNodeFound {
|
||||
t.Errorf("unexpected contact %s", id)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
ip, ok := contact[1].(string)
|
||||
if !ok {
|
||||
t.Error("contact IP is not a string")
|
||||
} else if ip != currNode.ip {
|
||||
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip)
|
||||
}
|
||||
|
||||
port, ok := contact[2].(int64)
|
||||
if !ok {
|
||||
t.Error("contact port is not an int")
|
||||
} else if int(port) != currNode.port {
|
||||
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -163,44 +163,12 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type findNodeDatum struct {
|
||||
ID bitmap
|
||||
IP string
|
||||
Port int
|
||||
}
|
||||
|
||||
func (f *findNodeDatum) UnmarshalBencode(b []byte) error {
|
||||
var contact []bencode.RawMessage
|
||||
err := bencode.DecodeBytes(b, &contact)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(contact) != 3 {
|
||||
return errors.Err("invalid-sized contact")
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(contact[0], &f.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bencode.DecodeBytes(contact[1], &f.IP)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = bencode.DecodeBytes(contact[2], &f.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type Response struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Data string
|
||||
FindNodeData []findNodeDatum
|
||||
FindNodeData []Node
|
||||
FindValueKey string
|
||||
}
|
||||
|
||||
func (r Response) MarshalBencode() ([]byte, error) {
|
||||
|
@ -212,11 +180,7 @@ func (r Response) MarshalBencode() ([]byte, error) {
|
|||
if r.Data != "" {
|
||||
data[headerPayloadField] = r.Data
|
||||
} else {
|
||||
var nodes []interface{}
|
||||
for _, n := range r.FindNodeData {
|
||||
nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port})
|
||||
}
|
||||
data[headerPayloadField] = nodes
|
||||
data[headerPayloadField] = r.FindNodeData
|
||||
}
|
||||
|
||||
return bencode.EncodeBytes(data)
|
||||
|
@ -226,7 +190,7 @@ func (r *Response) UnmarshalBencode(b []byte) error {
|
|||
var raw struct {
|
||||
ID string `bencode:"1"`
|
||||
NodeID string `bencode:"2"`
|
||||
Data bencode.RawMessage `bencode:"2"`
|
||||
Data bencode.RawMessage `bencode:"3"`
|
||||
}
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
|
|
20
dht/node.go
20
dht/node.go
|
@ -1,20 +0,0 @@
|
|||
package dht
|
||||
|
||||
const nodeIDLength = 48 // bytes
|
||||
const compactNodeInfoLength = nodeIDLength + 6
|
||||
|
||||
type Node struct {
|
||||
id bitmap
|
||||
addr string
|
||||
}
|
||||
|
||||
type SortedNode struct {
|
||||
node *Node
|
||||
sortKey bitmap
|
||||
}
|
||||
|
||||
type byXorDistance []*SortedNode
|
||||
|
||||
func (a byXorDistance) Len() int { return len(a) }
|
||||
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a byXorDistance) Less(i, j int) bool { return a[i].sortKey.Less(a[j].sortKey) }
|
|
@ -1,16 +1,113 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"container/list"
|
||||
"net"
|
||||
"sort"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
|
||||
"github.com/zeebo/bencode"
|
||||
)
|
||||
|
||||
type Node struct {
|
||||
id bitmap
|
||||
ip net.IP
|
||||
port int
|
||||
}
|
||||
|
||||
func (n Node) MarshalCompact() ([]byte, error) {
|
||||
if n.ip.To4() == nil {
|
||||
return nil, errors.Err("ip not set")
|
||||
}
|
||||
if n.port < 0 || n.port > 65535 {
|
||||
return nil, errors.Err("invalid port")
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.Write(n.ip.To4())
|
||||
buf.WriteByte(byte(n.port >> 8))
|
||||
buf.WriteByte(byte(n.port))
|
||||
buf.Write(n.id[:])
|
||||
|
||||
if buf.Len() != nodeIDLength+6 {
|
||||
return nil, errors.Err("i dont know how this happened")
|
||||
}
|
||||
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func (n *Node) UnmarshalCompact(b []byte) error {
|
||||
if len(b) != 6 {
|
||||
return errors.Err("invalid compact ip/port")
|
||||
}
|
||||
copy(n.ip, b[0:4])
|
||||
n.port = int(uint16(b[5]) | uint16(b[4])<<8)
|
||||
if n.port < 0 || n.port > 65535 {
|
||||
return errors.Err("invalid port")
|
||||
}
|
||||
n.id = newBitmapFromBytes(b[6:])
|
||||
return nil
|
||||
}
|
||||
|
||||
func (n Node) MarshalBencode() ([]byte, error) {
|
||||
return bencode.EncodeBytes([]interface{}{n.id, n.ip.String(), n.port})
|
||||
}
|
||||
|
||||
func (n *Node) UnmarshalBencode(b []byte) error {
|
||||
var raw []bencode.RawMessage
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(raw) != 3 {
|
||||
return errors.Err("contact must have 3 elements; got %d", len(raw))
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(raw[0], &n.id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var ipStr string
|
||||
err = bencode.DecodeBytes(raw[1], &ipStr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
n.ip = net.ParseIP(ipStr).To4()
|
||||
if n.ip == nil {
|
||||
return errors.Err("invalid IP")
|
||||
}
|
||||
|
||||
err = bencode.DecodeBytes(raw[2], &n.port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type SortedNode struct {
|
||||
node *Node
|
||||
xorDistanceToTarget bitmap
|
||||
}
|
||||
|
||||
type byXorDistance []*SortedNode
|
||||
|
||||
func (a byXorDistance) Len() int { return len(a) }
|
||||
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a byXorDistance) Less(i, j int) bool {
|
||||
return a[i].xorDistanceToTarget.Less(a[j].xorDistanceToTarget)
|
||||
}
|
||||
|
||||
type RoutingTable struct {
|
||||
node Node
|
||||
buckets [numBuckets]*list.List
|
||||
}
|
||||
|
||||
func NewRoutingTable(node *Node) *RoutingTable {
|
||||
func newRoutingTable(node *Node) *RoutingTable {
|
||||
var rt RoutingTable
|
||||
for i := range rt.buckets {
|
||||
rt.buckets[i] = list.New()
|
||||
|
@ -35,26 +132,26 @@ func (rt *RoutingTable) Update(node *Node) {
|
|||
}
|
||||
|
||||
func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
|
||||
toSort := []*SortedNode{}
|
||||
var toSort []*SortedNode
|
||||
|
||||
prefixLength := target.Xor(rt.node.id).PrefixLen()
|
||||
bucket := rt.buckets[prefixLength]
|
||||
appendNodes(bucket.Front(), nil, &toSort, target)
|
||||
toSort = appendNodes(toSort, bucket.Front(), nil, target)
|
||||
|
||||
for i := 1; (prefixLength-i >= 0 || prefixLength+i < nodeIDLength*8) && len(toSort) < count; i++ {
|
||||
if prefixLength-i >= 0 {
|
||||
bucket = rt.buckets[prefixLength-i]
|
||||
appendNodes(bucket.Front(), nil, &toSort, target)
|
||||
toSort = appendNodes(toSort, bucket.Front(), nil, target)
|
||||
}
|
||||
if prefixLength+i < nodeIDLength*8 {
|
||||
bucket = rt.buckets[prefixLength+i]
|
||||
appendNodes(bucket.Front(), nil, &toSort, target)
|
||||
toSort = appendNodes(toSort, bucket.Front(), nil, target)
|
||||
}
|
||||
}
|
||||
|
||||
sort.Sort(byXorDistance(toSort))
|
||||
|
||||
nodes := []*Node{}
|
||||
var nodes []*Node
|
||||
for _, c := range toSort {
|
||||
nodes = append(nodes, c.node)
|
||||
}
|
||||
|
@ -71,9 +168,10 @@ func findInList(bucket *list.List, value bitmap) *list.Element {
|
|||
return nil
|
||||
}
|
||||
|
||||
func appendNodes(start, end *list.Element, nodes *[]*SortedNode, target bitmap) {
|
||||
func appendNodes(nodes []*SortedNode, start, end *list.Element, target bitmap) []*SortedNode {
|
||||
for curr := start; curr != end; curr = curr.Next() {
|
||||
node := curr.Value.(*Node)
|
||||
*nodes = append(*nodes, &SortedNode{node, node.id.Xor(target)})
|
||||
nodes = append(nodes, &SortedNode{node, node.id.Xor(target)})
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
|
|
@ -1,14 +1,19 @@
|
|||
package dht
|
||||
|
||||
import "testing"
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
func TestRoutingTable(t *testing.T) {
|
||||
n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
|
||||
rt := NewRoutingTable(&Node{n1, "localhost:8000"})
|
||||
rt.Update(&Node{n2, "localhost:8001"})
|
||||
rt.Update(&Node{n3, "localhost:8002"})
|
||||
rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000})
|
||||
rt.Update(&Node{n2, net.ParseIP("127.0.0.1"), 8001})
|
||||
rt.Update(&Node{n3, net.ParseIP("127.0.0.1"), 8002})
|
||||
|
||||
contacts := rt.FindClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
|
||||
if len(contacts) != 1 {
|
||||
|
@ -31,3 +36,23 @@ func TestRoutingTable(t *testing.T) {
|
|||
t.Error(contacts[1])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompactEncoding(t *testing.T) {
|
||||
n := Node{
|
||||
id: newBitmapFromHex("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
|
||||
ip: net.ParseIP("255.1.0.155"),
|
||||
port: 66666,
|
||||
}
|
||||
|
||||
var compact []byte
|
||||
compact, err := n.MarshalCompact()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(compact) != nodeIDLength+6 {
|
||||
t.Fatalf("got length of %d; expected %d", len(compact), nodeIDLength+6)
|
||||
}
|
||||
|
||||
spew.Dump(compact)
|
||||
}
|
||||
|
|
|
@ -13,7 +13,7 @@ type peerStore struct {
|
|||
|
||||
func newPeerStore() *peerStore {
|
||||
return &peerStore{
|
||||
data: map[string][]peer{},
|
||||
data: make(map[string][]peer),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -32,7 +32,7 @@ func (s *peerStore) Insert(key string, nodeId bitmap) {
|
|||
func (s *peerStore) Get(key string) []bitmap {
|
||||
s.lock.RLock()
|
||||
defer s.lock.RUnlock()
|
||||
nodes := []bitmap{}
|
||||
var nodes []bitmap
|
||||
if peers, ok := s.data[key]; ok {
|
||||
for _, p := range peers {
|
||||
nodes = append(nodes, p.nodeID)
|
||||
|
|
Loading…
Reference in a new issue