switch request.Args to be a bitmap
This commit is contained in:
parent
79addd0b6e
commit
a1349b3889
4 changed files with 29 additions and 37 deletions
|
@ -77,7 +77,7 @@ type Request struct {
|
||||||
ID messageID
|
ID messageID
|
||||||
NodeID bitmap
|
NodeID bitmap
|
||||||
Method string
|
Method string
|
||||||
Args []string
|
Arg *bitmap
|
||||||
StoreArgs *storeArgs
|
StoreArgs *storeArgs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -85,8 +85,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
|
||||||
var args interface{}
|
var args interface{}
|
||||||
if r.StoreArgs != nil {
|
if r.StoreArgs != nil {
|
||||||
args = r.StoreArgs
|
args = r.StoreArgs
|
||||||
} else {
|
} else if r.Arg != nil {
|
||||||
args = r.Args
|
args = []bitmap{*r.Arg}
|
||||||
}
|
}
|
||||||
return bencode.EncodeBytes(map[string]interface{}{
|
return bencode.EncodeBytes(map[string]interface{}{
|
||||||
headerTypeField: requestType,
|
headerTypeField: requestType,
|
||||||
|
@ -116,25 +116,26 @@ func (r *Request) UnmarshalBencode(b []byte) error {
|
||||||
if r.Method == storeMethod {
|
if r.Method == storeMethod {
|
||||||
r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it.
|
r.StoreArgs = &storeArgs{} // bencode wont find the unmarshaler on a null pointer. need to fix it.
|
||||||
err = bencode.DecodeBytes(raw.Args, &r.StoreArgs)
|
err = bencode.DecodeBytes(raw.Args, &r.StoreArgs)
|
||||||
} else {
|
if err != nil {
|
||||||
err = bencode.DecodeBytes(raw.Args, &r.Args)
|
return errors.Prefix("request unmarshal", err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
} else if len(raw.Args) > 2 { // 2 because an empty list is `le`
|
||||||
return errors.Prefix("request unmarshal", err)
|
tmp := []bitmap{}
|
||||||
|
err = bencode.DecodeBytes(raw.Args, &tmp)
|
||||||
|
if err != nil {
|
||||||
|
return errors.Prefix("request unmarshal", err)
|
||||||
|
}
|
||||||
|
r.Arg = &tmp[0]
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r Request) ArgsDebug() string {
|
func (r Request) ArgsDebug() string {
|
||||||
argsCopy := make([]string, len(r.Args))
|
if r.Arg != nil {
|
||||||
copy(argsCopy, r.Args)
|
return r.Arg.HexShort()
|
||||||
for k, v := range argsCopy {
|
|
||||||
if len(v) == nodeIDLength {
|
|
||||||
argsCopy[k] = hex.EncodeToString([]byte(v))[:8]
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return strings.Join(argsCopy, ", ")
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
type storeArgsValue struct {
|
type storeArgsValue struct {
|
||||||
|
|
|
@ -99,7 +99,7 @@ func (nf *nodeFinder) iterationWorker(num int) {
|
||||||
continue // cannot contact self
|
continue // cannot contact self
|
||||||
}
|
}
|
||||||
|
|
||||||
req := Request{Args: []string{nf.target.RawString()}}
|
req := Request{Arg: &nf.target}
|
||||||
if nf.findValue {
|
if nf.findValue {
|
||||||
req.Method = findValueMethod
|
req.Method = findValueMethod
|
||||||
} else {
|
} else {
|
||||||
|
|
23
dht/rpc.go
23
dht/rpc.go
|
@ -79,30 +79,22 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
||||||
case findNodeMethod:
|
case findNodeMethod:
|
||||||
if len(request.Args) != 1 {
|
if request.Arg == nil {
|
||||||
log.Errorln("invalid number of args")
|
log.Errorln("request is missing arg")
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(request.Args[0]) != nodeIDLength {
|
|
||||||
log.Errorln("invalid node id")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
doFindNodes(dht, addr, request)
|
doFindNodes(dht, addr, request)
|
||||||
case findValueMethod:
|
case findValueMethod:
|
||||||
if len(request.Args) != 1 {
|
if request.Arg == nil {
|
||||||
log.Errorln("invalid number of args")
|
log.Errorln("request is missing arg")
|
||||||
return
|
|
||||||
}
|
|
||||||
if len(request.Args[0]) != nodeIDLength {
|
|
||||||
log.Errorln("invalid blob hash")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if nodes := dht.store.Get(newBitmapFromString(request.Args[0])); len(nodes) > 0 {
|
if nodes := dht.store.Get(*request.Arg); len(nodes) > 0 {
|
||||||
send(dht, addr, Response{
|
send(dht, addr, Response{
|
||||||
ID: request.ID,
|
ID: request.ID,
|
||||||
NodeID: dht.node.id,
|
NodeID: dht.node.id,
|
||||||
FindValueKey: request.Args[0],
|
FindValueKey: request.Arg.RawString(),
|
||||||
FindNodeData: nodes,
|
FindNodeData: nodes,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
@ -120,8 +112,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
nodeID := newBitmapFromString(request.Args[0])
|
closestNodes := dht.rt.GetClosest(*request.Arg, bucketSize)
|
||||||
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
|
|
||||||
if len(closestNodes) > 0 {
|
if len(closestNodes) > 0 {
|
||||||
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
|
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
|
||||||
for i, n := range closestNodes {
|
for i, n := range closestNodes {
|
||||||
|
|
|
@ -303,13 +303,13 @@ func TestFindNode(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
blobHashToFind := newRandomBitmap().RawString()
|
blobHashToFind := newRandomBitmap()
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID,
|
NodeID: testNodeID,
|
||||||
Method: findNodeMethod,
|
Method: findNodeMethod,
|
||||||
Args: []string{blobHashToFind},
|
Arg: &blobHashToFind,
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := bencode.EncodeBytes(request)
|
data, err := bencode.EncodeBytes(request)
|
||||||
|
@ -394,7 +394,7 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID,
|
NodeID: testNodeID,
|
||||||
Method: findValueMethod,
|
Method: findValueMethod,
|
||||||
Args: []string{valueToFind.RawString()},
|
Arg: &valueToFind,
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := bencode.EncodeBytes(request)
|
data, err := bencode.EncodeBytes(request)
|
||||||
|
@ -466,13 +466,13 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
messageID := newMessageID()
|
messageID := newMessageID()
|
||||||
valueToFind := newRandomBitmap().RawString()
|
valueToFind := newRandomBitmap()
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID,
|
NodeID: testNodeID,
|
||||||
Method: findValueMethod,
|
Method: findValueMethod,
|
||||||
Args: []string{valueToFind},
|
Arg: &valueToFind,
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := bencode.EncodeBytes(request)
|
data, err := bencode.EncodeBytes(request)
|
||||||
|
|
Loading…
Reference in a new issue