lots more work. findnode should work now

This commit is contained in:
Alex Grintsvayg 2018-03-08 19:50:18 -05:00
parent 04ad1692d8
commit 75b3376305
9 changed files with 392 additions and 153 deletions

2
Gopkg.lock generated
View file

@ -190,7 +190,7 @@
branch = "master"
name = "github.com/zeebo/bencode"
packages = ["."]
revision = "1f43a06f6eb53936bc028b38cdd060f0b5629c6c"
revision = "d522839ac797fc43269dae6a04a1f8be475a915d"
[[projects]]
branch = "master"

34
dht/decode_test.go Normal file

File diff suppressed because one or more lines are too long

View file

@ -14,8 +14,18 @@ import (
)
const network = "udp4"
const bucketSize = 20
const alpha = 3 // this is the constant alpha in the spec
const nodeIDLength = 48 // bytes. this is the constant B in the spec
const bucketSize = 20 // this is the constant k in the spec
const tExpire = 86400 * time.Second // the time after which a key/value pair expires; this is a time-to-live (TTL) from the original publication date
const tRefresh = 3600 * time.Second // after which an otherwise unaccessed bucket must be refreshed
const tReplicate = 3600 * time.Second // the interval between Kademlia replication events, when a node is required to publish its entire database
const tRepublish = 86400 * time.Second // the time after which the original publisher must republish a key/value pair
const numBuckets = nodeIDLength * 8
const compactNodeInfoLength = nodeIDLength + 6
// packet represents the information receive from udp.
type packet struct {
@ -67,11 +77,21 @@ func New(config *Config) *DHT {
} else {
id = newBitmapFromHex(config.NodeID)
}
node := &Node{id: id, addr: config.Address}
ip, port, err := net.SplitHostPort(config.Address)
if err != nil {
panic(err)
}
portInt, err := cast.ToIntE(port)
if err != nil {
panic(err)
}
node := &Node{id: id, ip: ip, port: portInt}
return &DHT{
conf: config,
node: node,
routingTable: NewRoutingTable(node),
routingTable: newRoutingTable(node),
packets: make(chan packet),
store: newPeerStore(),
}
@ -217,33 +237,46 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool)
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
case findNodeMethod:
log.Println("findnode")
if len(request.Args) < 1 {
log.Errorln("nothing to find")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return
}
nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize)
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes {
response.FindNodeData[i] = *n
}
send(dht, addr, response)
case findValueMethod:
log.Println("findvalue")
//if len(request.Args) < 1 {
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}})
// return
//}
//
//target := request.Args[0]
//if len(target) != nodeIDLength {
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"Invalid target"}})
// return
//}
//
//nodes := []findNodeDatum{}
//targetID := newBitmapFromString(target)
//
//no, _ := dht.routingTable.GetNodeKBucktByID(targetID)
//if no != nil {
// nodes = []findNodeDatum{{ID: no.id.RawString(), IP: no.addr.IP.String(), Port: no.addr.Port}}
//} else {
// neighbors := dht.routingTable.GetNeighbors(targetID, dht.K)
// for _, n := range neighbors {
// nodes = append(nodes, findNodeDatum{ID: n.id.RawString(), IP: n.addr.IP.String(), Port: n.addr.Port})
// }
//}
//
//send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: nodes})
if len(request.Args) < 1 {
log.Errorln("nothing to find")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return
}
nodeIDs := dht.store.Get(request.Args[0])
if len(nodeIDs) > 0 {
// return node ids
} else {
// switch to findNode
}
nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.routingTable.FindClosest(nodeID, bucketSize)
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes {
response.FindNodeData[i] = *n
}
send(dht, addr, response)
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
@ -251,7 +284,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool)
return
}
node := &Node{id: newBitmapFromString(request.NodeID), addr: addr.String()}
node := &Node{id: newBitmapFromString(request.NodeID), ip: addr.IP.String(), port: addr.Port}
dht.routingTable.Update(node)
return true
}
@ -271,7 +304,7 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success boo
// return
//}
node := &Node{id: newBitmapFromString(response.NodeID), addr: addr.String()}
node := &Node{id: newBitmapFromString(response.NodeID), ip: addr.IP.String(), port: addr.Port}
dht.routingTable.Update(node)
return true
@ -280,6 +313,8 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) (success boo
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) {
spew.Dump(e)
node := &Node{id: newBitmapFromString(e.NodeID), ip: addr.IP.String(), port: addr.Port}
dht.routingTable.Update(node)
return true
}
@ -288,7 +323,7 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], res.Data)
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], spew.Sdump(res.Data))
} else {
log.Debugf("[%s] %s", spew.Sdump(data))
}
@ -305,7 +340,7 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
}
func getArgs(argsInt interface{}) []string {
args := []string{}
var args []string
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
@ -316,10 +351,12 @@ func getArgs(argsInt interface{}) []string {
}
func argsToString(args []string) string {
for k, v := range args {
argsCopy := make([]string, len(args))
copy(argsCopy, args)
for k, v := range argsCopy {
if len(v) == nodeIDLength {
args[k] = hex.EncodeToString([]byte(v))[:8]
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
}
}
return strings.Join(args, ", ")
return strings.Join(argsCopy, ", ")
}

View file

@ -169,47 +169,9 @@ func TestStore(t *testing.T) {
}
}
if len(response) != 4 {
t.Errorf("expected 4 response fields, got %d", len(response))
}
verifyResponse(t, response, messageID, dhtNodeID.RawString())
_, ok := response[headerTypeField]
if !ok {
t.Error("missing type field")
} else {
rType, ok := response[headerTypeField].(int64)
if !ok {
t.Error("type is not an integer")
} else if rType != responseType {
t.Error("unexpected response type")
}
}
_, ok = response[headerMessageIDField]
if !ok {
t.Error("missing message id field")
} else {
rMessageID, ok := response[headerMessageIDField].(string)
if !ok {
t.Error("message ID is not a string")
} else if rMessageID != messageID {
t.Error("unexpected message ID")
}
}
_, ok = response[headerNodeIDField]
if !ok {
t.Error("missing node id field")
} else {
rNodeID, ok := response[headerNodeIDField].(string)
if !ok {
t.Error("node ID is not a string")
} else if rNodeID != dhtNodeID.RawString() {
t.Error("unexpected node ID")
}
}
_, ok = response[headerPayloadField]
_, ok := response[headerPayloadField]
if !ok {
t.Error("missing payload field")
} else {
@ -236,6 +198,7 @@ func TestStore(t *testing.T) {
func TestFindNode(t *testing.T) {
dhtNodeID := newRandomBitmap()
testNodeID := newRandomBitmap()
conn := newTestUDPConn("127.0.0.1:21217")
@ -244,23 +207,58 @@ func TestFindNode(t *testing.T) {
dht.listen()
go dht.runHandler()
data, _ := hex.DecodeString("64313a30693065313a3132303a2afdf2272981651a2c64e39ab7f04ec2d3b5d5d2313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33383a66696e644e6f6465313a346c34383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b6565")
nodesToInsert := 3
var nodes []Node
for i := 0; i < nodesToInsert; i++ {
n := Node{id: newRandomBitmap(), ip: "127.0.0.1", port: 10000 + i}
nodes = append(nodes, n)
dht.routingTable.Update(&n)
}
messageID := newRandomBitmap().RawString()
blobHashToFind := newRandomBitmap().RawString()
request := Request{
ID: messageID,
NodeID: testNodeID.RawString(),
Method: findNodeMethod,
Args: []string{blobHashToFind},
}
data, err := bencode.EncodeBytes(request)
if err != nil {
t.Error(err)
return
}
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
timer := time.NewTimer(3 * time.Second)
var response map[string]interface{}
select {
case <-timer.C:
t.Error("timeout")
return
case resp := <-conn.writes:
var response map[string]interface{}
err := bencode.DecodeBytes(resp.data, &response)
if err != nil {
t.Error(err)
return
}
}
spew.Dump(response)
verifyResponse(t, response, messageID, dhtNodeID.RawString())
_, ok := response[headerPayloadField]
if !ok {
t.Error("missing payload field")
} else {
contacts, ok := response[headerPayloadField].([]interface{})
if !ok {
t.Error("payload is not a list")
} else {
verifyContacts(t, contacts, nodes)
}
}
}
@ -293,3 +291,106 @@ func TestFindValue(t *testing.T) {
spew.Dump(response)
}
}
func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) {
if len(resp) != 4 {
t.Errorf("expected 4 response fields, got %d", len(resp))
}
_, ok := resp[headerTypeField]
if !ok {
t.Error("missing type field")
} else {
rType, ok := resp[headerTypeField].(int64)
if !ok {
t.Error("type is not an integer")
} else if rType != responseType {
t.Error("unexpected response type")
}
}
_, ok = resp[headerMessageIDField]
if !ok {
t.Error("missing message id field")
} else {
rMessageID, ok := resp[headerMessageIDField].(string)
if !ok {
t.Error("message ID is not a string")
} else if rMessageID != messageID {
t.Error("unexpected message ID")
}
}
_, ok = resp[headerNodeIDField]
if !ok {
t.Error("missing node id field")
} else {
rNodeID, ok := resp[headerNodeIDField].(string)
if !ok {
t.Error("node ID is not a string")
} else if rNodeID != dhtNodeID {
t.Error("unexpected node ID")
}
}
}
func verifyContacts(t *testing.T, contacts []interface{}, nodes []Node) {
if len(contacts) != len(nodes) {
t.Errorf("got %d contacts; expected %d", len(contacts), len(nodes))
return
}
foundNodes := make(map[string]bool)
for _, c := range contacts {
contact, ok := c.([]interface{})
if !ok {
t.Error("contact is not a list")
return
}
if len(contact) != 3 {
t.Error("contact must be 3 items")
return
}
var currNode Node
currNodeFound := false
id, ok := contact[0].(string)
if !ok {
t.Error("contact id is not a string")
} else {
if _, ok := foundNodes[id]; ok {
t.Errorf("contact %s appears multiple times", id)
continue
}
for _, n := range nodes {
if n.id.RawString() == id {
currNode = n
currNodeFound = true
foundNodes[id] = true
break
}
}
if !currNodeFound {
t.Errorf("unexpected contact %s", id)
continue
}
}
ip, ok := contact[1].(string)
if !ok {
t.Error("contact IP is not a string")
} else if ip != currNode.ip {
t.Errorf("contact IP mismatch. got %s; expected %s", ip, currNode.ip)
}
port, ok := contact[2].(int64)
if !ok {
t.Error("contact port is not an int")
} else if int(port) != currNode.port {
t.Errorf("contact port mismatch. got %d; expected %d", port, currNode.port)
}
}
}

View file

@ -163,44 +163,12 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
return nil
}
type findNodeDatum struct {
ID bitmap
IP string
Port int
}
func (f *findNodeDatum) UnmarshalBencode(b []byte) error {
var contact []bencode.RawMessage
err := bencode.DecodeBytes(b, &contact)
if err != nil {
return err
}
if len(contact) != 3 {
return errors.Err("invalid-sized contact")
}
err = bencode.DecodeBytes(contact[0], &f.ID)
if err != nil {
return err
}
err = bencode.DecodeBytes(contact[1], &f.IP)
if err != nil {
return err
}
err = bencode.DecodeBytes(contact[2], &f.Port)
if err != nil {
return err
}
return nil
}
type Response struct {
ID string
NodeID string
Data string
FindNodeData []findNodeDatum
FindNodeData []Node
FindValueKey string
}
func (r Response) MarshalBencode() ([]byte, error) {
@ -212,11 +180,7 @@ func (r Response) MarshalBencode() ([]byte, error) {
if r.Data != "" {
data[headerPayloadField] = r.Data
} else {
var nodes []interface{}
for _, n := range r.FindNodeData {
nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port})
}
data[headerPayloadField] = nodes
data[headerPayloadField] = r.FindNodeData
}
return bencode.EncodeBytes(data)
@ -226,7 +190,7 @@ func (r *Response) UnmarshalBencode(b []byte) error {
var raw struct {
ID string `bencode:"1"`
NodeID string `bencode:"2"`
Data bencode.RawMessage `bencode:"2"`
Data bencode.RawMessage `bencode:"3"`
}
err := bencode.DecodeBytes(b, &raw)
if err != nil {

View file

@ -1,20 +0,0 @@
package dht
const nodeIDLength = 48 // bytes
const compactNodeInfoLength = nodeIDLength + 6
type Node struct {
id bitmap
addr string
}
type SortedNode struct {
node *Node
sortKey bitmap
}
type byXorDistance []*SortedNode
func (a byXorDistance) Len() int { return len(a) }
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byXorDistance) Less(i, j int) bool { return a[i].sortKey.Less(a[j].sortKey) }

View file

@ -1,16 +1,113 @@
package dht
import (
"bytes"
"container/list"
"net"
"sort"
"github.com/lbryio/errors.go"
"github.com/zeebo/bencode"
)
type Node struct {
id bitmap
ip net.IP
port int
}
func (n Node) MarshalCompact() ([]byte, error) {
if n.ip.To4() == nil {
return nil, errors.Err("ip not set")
}
if n.port < 0 || n.port > 65535 {
return nil, errors.Err("invalid port")
}
var buf bytes.Buffer
buf.Write(n.ip.To4())
buf.WriteByte(byte(n.port >> 8))
buf.WriteByte(byte(n.port))
buf.Write(n.id[:])
if buf.Len() != nodeIDLength+6 {
return nil, errors.Err("i dont know how this happened")
}
return buf.Bytes(), nil
}
func (n *Node) UnmarshalCompact(b []byte) error {
if len(b) != 6 {
return errors.Err("invalid compact ip/port")
}
copy(n.ip, b[0:4])
n.port = int(uint16(b[5]) | uint16(b[4])<<8)
if n.port < 0 || n.port > 65535 {
return errors.Err("invalid port")
}
n.id = newBitmapFromBytes(b[6:])
return nil
}
func (n Node) MarshalBencode() ([]byte, error) {
return bencode.EncodeBytes([]interface{}{n.id, n.ip.String(), n.port})
}
func (n *Node) UnmarshalBencode(b []byte) error {
var raw []bencode.RawMessage
err := bencode.DecodeBytes(b, &raw)
if err != nil {
return err
}
if len(raw) != 3 {
return errors.Err("contact must have 3 elements; got %d", len(raw))
}
err = bencode.DecodeBytes(raw[0], &n.id)
if err != nil {
return err
}
var ipStr string
err = bencode.DecodeBytes(raw[1], &ipStr)
if err != nil {
return err
}
n.ip = net.ParseIP(ipStr).To4()
if n.ip == nil {
return errors.Err("invalid IP")
}
err = bencode.DecodeBytes(raw[2], &n.port)
if err != nil {
return err
}
return nil
}
type SortedNode struct {
node *Node
xorDistanceToTarget bitmap
}
type byXorDistance []*SortedNode
func (a byXorDistance) Len() int { return len(a) }
func (a byXorDistance) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a byXorDistance) Less(i, j int) bool {
return a[i].xorDistanceToTarget.Less(a[j].xorDistanceToTarget)
}
type RoutingTable struct {
node Node
buckets [numBuckets]*list.List
}
func NewRoutingTable(node *Node) *RoutingTable {
func newRoutingTable(node *Node) *RoutingTable {
var rt RoutingTable
for i := range rt.buckets {
rt.buckets[i] = list.New()
@ -35,26 +132,26 @@ func (rt *RoutingTable) Update(node *Node) {
}
func (rt *RoutingTable) FindClosest(target bitmap, count int) []*Node {
toSort := []*SortedNode{}
var toSort []*SortedNode
prefixLength := target.Xor(rt.node.id).PrefixLen()
bucket := rt.buckets[prefixLength]
appendNodes(bucket.Front(), nil, &toSort, target)
toSort = appendNodes(toSort, bucket.Front(), nil, target)
for i := 1; (prefixLength-i >= 0 || prefixLength+i < nodeIDLength*8) && len(toSort) < count; i++ {
if prefixLength-i >= 0 {
bucket = rt.buckets[prefixLength-i]
appendNodes(bucket.Front(), nil, &toSort, target)
toSort = appendNodes(toSort, bucket.Front(), nil, target)
}
if prefixLength+i < nodeIDLength*8 {
bucket = rt.buckets[prefixLength+i]
appendNodes(bucket.Front(), nil, &toSort, target)
toSort = appendNodes(toSort, bucket.Front(), nil, target)
}
}
sort.Sort(byXorDistance(toSort))
nodes := []*Node{}
var nodes []*Node
for _, c := range toSort {
nodes = append(nodes, c.node)
}
@ -71,9 +168,10 @@ func findInList(bucket *list.List, value bitmap) *list.Element {
return nil
}
func appendNodes(start, end *list.Element, nodes *[]*SortedNode, target bitmap) {
func appendNodes(nodes []*SortedNode, start, end *list.Element, target bitmap) []*SortedNode {
for curr := start; curr != end; curr = curr.Next() {
node := curr.Value.(*Node)
*nodes = append(*nodes, &SortedNode{node, node.id.Xor(target)})
nodes = append(nodes, &SortedNode{node, node.id.Xor(target)})
}
return nodes
}

View file

@ -1,14 +1,19 @@
package dht
import "testing"
import (
"net"
"testing"
"github.com/davecgh/go-spew/spew"
)
func TestRoutingTable(t *testing.T) {
n1 := newBitmapFromHex("FFFFFFFF0000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n2 := newBitmapFromHex("FFFFFFF00000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
n3 := newBitmapFromHex("111111110000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")
rt := NewRoutingTable(&Node{n1, "localhost:8000"})
rt.Update(&Node{n2, "localhost:8001"})
rt.Update(&Node{n3, "localhost:8002"})
rt := newRoutingTable(&Node{n1, net.ParseIP("127.0.0.1"), 8000})
rt.Update(&Node{n2, net.ParseIP("127.0.0.1"), 8001})
rt.Update(&Node{n3, net.ParseIP("127.0.0.1"), 8002})
contacts := rt.FindClosest(newBitmapFromHex("222222220000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"), 1)
if len(contacts) != 1 {
@ -31,3 +36,23 @@ func TestRoutingTable(t *testing.T) {
t.Error(contacts[1])
}
}
func TestCompactEncoding(t *testing.T) {
n := Node{
id: newBitmapFromHex("1c8aff71b99462464d9eeac639595ab99664be3482cb91a29d87467515c7d9158fe72aa1f1582dab07d8f8b5db277f41"),
ip: net.ParseIP("255.1.0.155"),
port: 66666,
}
var compact []byte
compact, err := n.MarshalCompact()
if err != nil {
t.Fatal(err)
}
if len(compact) != nodeIDLength+6 {
t.Fatalf("got length of %d; expected %d", len(compact), nodeIDLength+6)
}
spew.Dump(compact)
}

View file

@ -13,7 +13,7 @@ type peerStore struct {
func newPeerStore() *peerStore {
return &peerStore{
data: map[string][]peer{},
data: make(map[string][]peer),
}
}
@ -32,7 +32,7 @@ func (s *peerStore) Insert(key string, nodeId bitmap) {
func (s *peerStore) Get(key string) []bitmap {
s.lock.RLock()
defer s.lock.RUnlock()
nodes := []bitmap{}
var nodes []bitmap
if peers, ok := s.data[key]; ok {
for _, p := range peers {
nodes = append(nodes, p.nodeID)