diff --git a/dht/message.go b/dht/message.go index fbd123f..cf66580 100644 --- a/dht/message.go +++ b/dht/message.go @@ -77,7 +77,7 @@ type Request struct { ID messageID NodeID bitmap Method string - Args []string + Arg *bitmap StoreArgs *storeArgs } @@ -85,8 +85,8 @@ func (r Request) MarshalBencode() ([]byte, error) { var args interface{} if r.StoreArgs != nil { args = r.StoreArgs - } else { - args = r.Args + } else if r.Arg != nil { + args = []bitmap{*r.Arg} } return bencode.EncodeBytes(map[string]interface{}{ headerTypeField: requestType, @@ -116,25 +116,26 @@ func (r *Request) UnmarshalBencode(b []byte) error { if r.Method == storeMethod { r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it. err = bencode.DecodeBytes(raw.Args, &r.StoreArgs) - } else { - err = bencode.DecodeBytes(raw.Args, &r.Args) - } - if err != nil { - return errors.Prefix("request unmarshal", err) + if err != nil { + return errors.Prefix("request unmarshal", err) + } + } else if len(raw.Args) > 2 { // 2 because an empty list is `le` + tmp := []bitmap{} + err = bencode.DecodeBytes(raw.Args, &tmp) + if err != nil { + return errors.Prefix("request unmarshal", err) + } + r.Arg = &tmp[0] } return nil } func (r Request) ArgsDebug() string { - argsCopy := make([]string, len(r.Args)) - copy(argsCopy, r.Args) - for k, v := range argsCopy { - if len(v) == nodeIDLength { - argsCopy[k] = hex.EncodeToString([]byte(v))[:8] - } + if r.Arg != nil { + return r.Arg.HexShort() } - return strings.Join(argsCopy, ", ") + return "" } type storeArgsValue struct { diff --git a/dht/node_finder.go b/dht/node_finder.go index 73ff523..dc43a83 100644 --- a/dht/node_finder.go +++ b/dht/node_finder.go @@ -99,7 +99,7 @@ func (nf *nodeFinder) iterationWorker(num int) { continue // cannot contact self } - req := Request{Args: []string{nf.target.RawString()}} + req := Request{Arg: &nf.target} if nf.findValue { req.Method = findValueMethod } else { diff --git a/dht/rpc.go b/dht/rpc.go index 60ccec9..7236f53 100644 --- a/dht/rpc.go +++ b/dht/rpc.go @@ -79,30 +79,22 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse}) case findNodeMethod: - if len(request.Args) != 1 { - log.Errorln("invalid number of args") - return - } - if len(request.Args[0]) != nodeIDLength { - log.Errorln("invalid node id") + if request.Arg == nil { + log.Errorln("request is missing arg") return } doFindNodes(dht, addr, request) case findValueMethod: - if len(request.Args) != 1 { - log.Errorln("invalid number of args") - return - } - if len(request.Args[0]) != nodeIDLength { - log.Errorln("invalid blob hash") + if request.Arg == nil { + log.Errorln("request is missing arg") return } - if nodes := dht.store.Get(newBitmapFromString(request.Args[0])); len(nodes) > 0 { + if nodes := dht.store.Get(*request.Arg); len(nodes) > 0 { send(dht, addr, Response{ ID: request.ID, NodeID: dht.node.id, - FindValueKey: request.Args[0], + FindValueKey: request.Arg.RawString(), FindNodeData: nodes, }) } else { @@ -120,8 +112,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { } func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { - nodeID := newBitmapFromString(request.Args[0]) - closestNodes := dht.rt.GetClosest(nodeID, bucketSize) + closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize) if len(closestNodes) > 0 { response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))} for i, n := range closestNodes { diff --git a/dht/rpc_test.go b/dht/rpc_test.go index 8129941..808a26d 100644 --- a/dht/rpc_test.go +++ b/dht/rpc_test.go @@ -303,13 +303,13 @@ func TestFindNode(t *testing.T) { } messageID := newMessageID() - blobHashToFind := newRandomBitmap().RawString() + blobHashToFind := newRandomBitmap() request := Request{ ID: messageID, NodeID: testNodeID, Method: findNodeMethod, - Args: []string{blobHashToFind}, + Arg: &blobHashToFind, } data, err := bencode.EncodeBytes(request) @@ -394,7 +394,7 @@ func TestFindValueExisting(t *testing.T) { ID: messageID, NodeID: testNodeID, Method: findValueMethod, - Args: []string{valueToFind.RawString()}, + Arg: &valueToFind, } data, err := bencode.EncodeBytes(request) @@ -466,13 +466,13 @@ func TestFindValueFallbackToFindNode(t *testing.T) { } messageID := newMessageID() - valueToFind := newRandomBitmap().RawString() + valueToFind := newRandomBitmap() request := Request{ ID: messageID, NodeID: testNodeID, Method: findValueMethod, - Args: []string{valueToFind}, + Arg: &valueToFind, } data, err := bencode.EncodeBytes(request)