switch request.Args to be a bitmap
This commit is contained in:
parent
dd8333db33
commit
d9acce359f
4 changed files with 29 additions and 37 deletions
|
@ -77,7 +77,7 @@ type Request struct {
|
|||
ID messageID
|
||||
NodeID bitmap
|
||||
Method string
|
||||
Args []string
|
||||
Arg *bitmap
|
||||
StoreArgs *storeArgs
|
||||
}
|
||||
|
||||
|
@ -85,8 +85,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
|
|||
var args interface{}
|
||||
if r.StoreArgs != nil {
|
||||
args = r.StoreArgs
|
||||
} else {
|
||||
args = r.Args
|
||||
} else if r.Arg != nil {
|
||||
args = []bitmap{*r.Arg}
|
||||
}
|
||||
return bencode.EncodeBytes(map[string]interface{}{
|
||||
headerTypeField: requestType,
|
||||
|
@ -116,25 +116,26 @@ func (r *Request) UnmarshalBencode(b []byte) error {
|
|||
if r.Method == storeMethod {
|
||||
r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it.
|
||||
err = bencode.DecodeBytes(raw.Args, &r.StoreArgs)
|
||||
} else {
|
||||
err = bencode.DecodeBytes(raw.Args, &r.Args)
|
||||
}
|
||||
if err != nil {
|
||||
return errors.Prefix("request unmarshal", err)
|
||||
}
|
||||
} else if len(raw.Args) > 2 { // 2 because an empty list is `le`
|
||||
tmp := []bitmap{}
|
||||
err = bencode.DecodeBytes(raw.Args, &tmp)
|
||||
if err != nil {
|
||||
return errors.Prefix("request unmarshal", err)
|
||||
}
|
||||
r.Arg = &tmp[0]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r Request) ArgsDebug() string {
|
||||
argsCopy := make([]string, len(r.Args))
|
||||
copy(argsCopy, r.Args)
|
||||
for k, v := range argsCopy {
|
||||
if len(v) == nodeIDLength {
|
||||
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
||||
if r.Arg != nil {
|
||||
return r.Arg.HexShort()
|
||||
}
|
||||
}
|
||||
return strings.Join(argsCopy, ", ")
|
||||
return ""
|
||||
}
|
||||
|
||||
type storeArgsValue struct {
|
||||
|
|
|
@ -99,7 +99,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
|||
continue // cannot contact self
|
||||
}
|
||||
|
||||
req := Request{Args: []string{nf.target.RawString()}}
|
||||
req := Request{Arg: &nf.target}
|
||||
if nf.findValue {
|
||||
req.Method = findValueMethod
|
||||
} else {
|
||||
|
|
23
dht/rpc.go
23
dht/rpc.go
|
@ -79,30 +79,22 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
|||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
||||
case findNodeMethod:
|
||||
if len(request.Args) != 1 {
|
||||
log.Errorln("invalid number of args")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid node id")
|
||||
if request.Arg == nil {
|
||||
log.Errorln("request is missing arg")
|
||||
return
|
||||
}
|
||||
doFindNodes(dht, addr, request)
|
||||
case findValueMethod:
|
||||
if len(request.Args) != 1 {
|
||||
log.Errorln("invalid number of args")
|
||||
return
|
||||
}
|
||||
if len(request.Args[0]) != nodeIDLength {
|
||||
log.Errorln("invalid blob hash")
|
||||
if request.Arg == nil {
|
||||
log.Errorln("request is missing arg")
|
||||
return
|
||||
}
|
||||
|
||||
if nodes := dht.store.Get(newBitmapFromString(request.Args[0])); len(nodes) > 0 {
|
||||
if nodes := dht.store.Get(*request.Arg); len(nodes) > 0 {
|
||||
send(dht, addr, Response{
|
||||
ID: request.ID,
|
||||
NodeID: dht.node.id,
|
||||
FindValueKey: request.Args[0],
|
||||
FindValueKey: request.Arg.RawString(),
|
||||
FindNodeData: nodes,
|
||||
})
|
||||
} else {
|
||||
|
@ -120,8 +112,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
|||
}
|
||||
|
||||
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||
nodeID := newBitmapFromString(request.Args[0])
|
||||
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
|
||||
closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize)
|
||||
if len(closestNodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
|
||||
for i, n := range closestNodes {
|
||||
|
|
|
@ -303,13 +303,13 @@ func TestFindNode(t *testing.T) {
|
|||
}
|
||||
|
||||
messageID := newMessageID()
|
||||
blobHashToFind := newRandomBitmap().RawString()
|
||||
blobHashToFind := newRandomBitmap()
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: findNodeMethod,
|
||||
Args: []string{blobHashToFind},
|
||||
Arg: &blobHashToFind,
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
|
@ -394,7 +394,7 @@ func TestFindValueExisting(t *testing.T) {
|
|||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: findValueMethod,
|
||||
Args: []string{valueToFind.RawString()},
|
||||
Arg: &valueToFind,
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
|
@ -466,13 +466,13 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
|||
}
|
||||
|
||||
messageID := newMessageID()
|
||||
valueToFind := newRandomBitmap().RawString()
|
||||
valueToFind := newRandomBitmap()
|
||||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID,
|
||||
Method: findValueMethod,
|
||||
Args: []string{valueToFind},
|
||||
Arg: &valueToFind,
|
||||
}
|
||||
|
||||
data, err := bencode.EncodeBytes(request)
|
||||
|
|
Loading…
Reference in a new issue