switch request.Args to be a bitmap

This commit is contained in:
Alex Grintsvayg 2018-04-04 12:01:44 -04:00
parent 79addd0b6e
commit a1349b3889
4 changed files with 29 additions and 37 deletions

View file

@ -77,7 +77,7 @@ type Request struct {
ID messageID ID messageID
NodeID bitmap NodeID bitmap
Method string Method string
Args []string Arg *bitmap
StoreArgs *storeArgs StoreArgs *storeArgs
} }
@ -85,8 +85,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
var args interface{} var args interface{}
if r.StoreArgs != nil { if r.StoreArgs != nil {
args = r.StoreArgs args = r.StoreArgs
} else { } else if r.Arg != nil {
args = r.Args args = []bitmap{*r.Arg}
} }
return bencode.EncodeBytes(map[string]interface{}{ return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: requestType, headerTypeField: requestType,
@ -116,25 +116,26 @@ func (r *Request) UnmarshalBencode(b []byte) error {
if r.Method == storeMethod { if r.Method == storeMethod {
r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it. r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it.
err = bencode.DecodeBytes(raw.Args, &r.StoreArgs) err = bencode.DecodeBytes(raw.Args, &r.StoreArgs)
} else { if err != nil {
err = bencode.DecodeBytes(raw.Args, &r.Args) return errors.Prefix("request unmarshal", err)
} }
if err != nil { } else if len(raw.Args) > 2 { // 2 because an empty list is `le`
return errors.Prefix("request unmarshal", err) tmp := []bitmap{}
err = bencode.DecodeBytes(raw.Args, &tmp)
if err != nil {
return errors.Prefix("request unmarshal", err)
}
r.Arg = &tmp[0]
} }
return nil return nil
} }
func (r Request) ArgsDebug() string { func (r Request) ArgsDebug() string {
argsCopy := make([]string, len(r.Args)) if r.Arg != nil {
copy(argsCopy, r.Args) return r.Arg.HexShort()
for k, v := range argsCopy {
if len(v) == nodeIDLength {
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
}
} }
return strings.Join(argsCopy, ", ") return ""
} }
type storeArgsValue struct { type storeArgsValue struct {

View file

@ -99,7 +99,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
continue // cannot contact self continue // cannot contact self
} }
req := Request{Args: []string{nf.target.RawString()}} req := Request{Arg: &nf.target}
if nf.findValue { if nf.findValue {
req.Method = findValueMethod req.Method = findValueMethod
} else { } else {

View file

@ -79,30 +79,22 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse}) send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
case findNodeMethod: case findNodeMethod:
if len(request.Args) != 1 { if request.Arg == nil {
log.Errorln("invalid number of args") log.Errorln("request is missing arg")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid node id")
return return
} }
doFindNodes(dht, addr, request) doFindNodes(dht, addr, request)
case findValueMethod: case findValueMethod:
if len(request.Args) != 1 { if request.Arg == nil {
log.Errorln("invalid number of args") log.Errorln("request is missing arg")
return
}
if len(request.Args[0]) != nodeIDLength {
log.Errorln("invalid blob hash")
return return
} }
if nodes := dht.store.Get(newBitmapFromString(request.Args[0])); len(nodes) > 0 { if nodes := dht.store.Get(*request.Arg); len(nodes) > 0 {
send(dht, addr, Response{ send(dht, addr, Response{
ID: request.ID, ID: request.ID,
NodeID: dht.node.id, NodeID: dht.node.id,
FindValueKey: request.Args[0], FindValueKey: request.Arg.RawString(),
FindNodeData: nodes, FindNodeData: nodes,
}) })
} else { } else {
@ -120,8 +112,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
} }
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
nodeID := newBitmapFromString(request.Args[0]) closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize)
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
if len(closestNodes) > 0 { if len(closestNodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))} response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes { for i, n := range closestNodes {

View file

@ -303,13 +303,13 @@ func TestFindNode(t *testing.T) {
} }
messageID := newMessageID() messageID := newMessageID()
blobHashToFind := newRandomBitmap().RawString() blobHashToFind := newRandomBitmap()
request := Request{ request := Request{
ID: messageID, ID: messageID,
NodeID: testNodeID, NodeID: testNodeID,
Method: findNodeMethod, Method: findNodeMethod,
Args: []string{blobHashToFind}, Arg: &blobHashToFind,
} }
data, err := bencode.EncodeBytes(request) data, err := bencode.EncodeBytes(request)
@ -394,7 +394,7 @@ func TestFindValueExisting(t *testing.T) {
ID: messageID, ID: messageID,
NodeID: testNodeID, NodeID: testNodeID,
Method: findValueMethod, Method: findValueMethod,
Args: []string{valueToFind.RawString()}, Arg: &valueToFind,
} }
data, err := bencode.EncodeBytes(request) data, err := bencode.EncodeBytes(request)
@ -466,13 +466,13 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
} }
messageID := newMessageID() messageID := newMessageID()
valueToFind := newRandomBitmap().RawString() valueToFind := newRandomBitmap()
request := Request{ request := Request{
ID: messageID, ID: messageID,
NodeID: testNodeID, NodeID: testNodeID,
Method: findValueMethod, Method: findValueMethod,
Args: []string{valueToFind}, Arg: &valueToFind,
} }
data, err := bencode.EncodeBytes(request) data, err := bencode.EncodeBytes(request)