From 006a49bd6716df6ff83cdedd1149fb5b62d1deba Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Wed, 7 Mar 2018 16:15:58 -0500 Subject: [PATCH] better tests, better bencoding --- dht/dht.go | 87 +++++---------- dht/dht_test.go | 106 ++++++++++++++++-- dht/message.go | 267 ++++++++++++++++++++++++++++++++++++++++++++ dht/message_test.go | 75 +++++++++++++ dht/messages.go | 103 ----------------- dht/store.go | 48 ++++++++ 6 files changed, 517 insertions(+), 169 deletions(-) create mode 100644 dht/message.go create mode 100644 dht/message_test.go delete mode 100644 dht/messages.go create mode 100644 dht/store.go diff --git a/dht/dht.go b/dht/dht.go index 0cdedda..600f327 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -52,6 +52,7 @@ type DHT struct { node *Node routingTable *RoutingTable packets chan packet + store *peerStore } // New returns a DHT pointer. If config is nil, then config will be set to the default config. @@ -72,6 +73,7 @@ func New(config *Config) *DHT { node: node, routingTable: NewRoutingTable(node), packets: make(chan packet), + store: newPeerStore(), } } @@ -150,39 +152,33 @@ func handle(dht *DHT, pkt packet) { var data map[string]interface{} err := bencode.DecodeBytes(pkt.data, &data) if err != nil { - log.Errorf("Error decoding data: %s\n%s", err, pkt.data) + log.Errorf("error decoding data: %s\n%s", err, pkt.data) return } msgType, ok := data[headerTypeField] if !ok { - log.Errorf("Decoded data has no message type: %s", data) + log.Errorf("decoded data has no message type: %s", data) return } switch msgType.(int64) { case requestType: - request := Request{ - ID: data[headerMessageIDField].(string), - NodeID: data[headerNodeIDField].(string), - Method: data[headerPayloadField].(string), - Args: getArgs(data[headerArgsField]), + request := Request{} + err = bencode.DecodeBytes(pkt.data, &request) + if err != nil { + return } - log.Infof("%s: Received from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) + log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) handleRequest(dht, pkt.raddr, request) case responseType: - response := Response{ - ID: data[headerMessageIDField].(string), - NodeID: data[headerNodeIDField].(string), + response := Response{} + err = bencode.DecodeBytes(pkt.data, &response) + if err != nil { + return } - - if reflect.TypeOf(data[headerPayloadField]).Kind() == reflect.String { - response.Data = data[headerPayloadField].(string) - } else { - response.FindNodeData = getFindNodeResponse(data[headerPayloadField]) - } - + log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data) handleResponse(dht, pkt.raddr, response) case errorType: @@ -192,6 +188,7 @@ func handle(dht *DHT, pkt packet) { ExceptionType: data[headerPayloadField].(string), Response: getArgs(data[headerArgsField]), } + log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) handleError(dht, pkt.raddr, e) default: @@ -202,7 +199,6 @@ func handle(dht *DHT, pkt packet) { // handleRequest handles the requests received from udp. func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) { - log.Infoln("handling request") if request.NodeID == dht.node.id.RawString() { log.Warn("ignoring self-request") return @@ -211,11 +207,16 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) switch request.Method { case pingMethod: log.Println("ping") - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: "pong"}) + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse}) case storeMethod: log.Println("store") + node := &Node{id: newBitmapFromHex(request.StoreArgs.Value.LbryID), addr: request.StoreArgs.Value.Port} + dht.store.Insert(newBitmapFromHex(request.StoreArgs.BlobHash)) + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse}) case findNodeMethod: log.Println("findnode") + case findValueMethod: + log.Println("findvalue") //if len(request.Args) < 1 { // send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}}) // return @@ -244,6 +245,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) default: // send(dht, addr, makeError(t, protocolError, "invalid q")) + log.Errorln("invalid request method") return } @@ -282,11 +284,13 @@ func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) { // send sends data to the udp. func send(dht *DHT, addr *net.UDPAddr, data Message) error { if req, ok := data.(Request); ok { - log.Infof("%s: Sending %s(%s)", hex.EncodeToString([]byte(req.NodeID))[:8], req.Method, argsToString(req.Args)) + log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args)) + } else if res, ok := data.(Response); ok { + log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], res.Data) } else { - log.Infof("%s: Sending %s", data.GetID(), spew.Sdump(data)) + log.Debugf("[%s] %s", spew.Sdump(data)) } - encoded, err := data.Encode() + encoded, err := bencode.EncodeBytes(data) if err != nil { return err } @@ -298,46 +302,15 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error { return err } -func getFindNodeResponse(i interface{}) (data []findNodeDatum) { - if reflect.TypeOf(i).Kind() != reflect.Slice { - return - } - - v := reflect.ValueOf(i) - for i := 0; i < v.Len(); i++ { - if v.Index(i).Kind() != reflect.Interface { - continue - } - - contact := v.Index(i).Elem() - if contact.Type().Kind() != reflect.Slice || contact.Len() != 3 { - continue - } - - if contact.Index(0).Elem().Kind() != reflect.String || - contact.Index(1).Elem().Kind() != reflect.String || - !(contact.Index(2).Elem().Kind() == reflect.Int64 || - contact.Index(2).Elem().Kind() == reflect.Int) { - continue - } - - data = append(data, findNodeDatum{ - ID: contact.Index(0).Elem().String(), - IP: contact.Index(1).Elem().String(), - Port: int(contact.Index(2).Elem().Int()), - }) - } - return -} - -func getArgs(argsInt interface{}) (args []string) { +func getArgs(argsInt interface{}) []string { + args := []string{} if reflect.TypeOf(argsInt).Kind() == reflect.Slice { v := reflect.ValueOf(argsInt) for i := 0; i < v.Len(); i++ { args = append(args, cast.ToString(v.Index(i).Interface())) } } - return + return args } func argsToString(args []string) string { diff --git a/dht/dht_test.go b/dht/dht_test.go index 8d5d9a8..a3589b2 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -1,13 +1,17 @@ package dht import ( + "encoding/hex" "testing" "time" + "github.com/davecgh/go-spew/spew" + log "github.com/sirupsen/logrus" "github.com/zeebo/bencode" ) func TestPing(t *testing.T) { + log.SetLevel(log.DebugLevel) dhtNodeID := newRandomBitmap() testNodeID := newRandomBitmap() @@ -111,17 +115,41 @@ func TestStore(t *testing.T) { go dht.runHandler() messageID := newRandomBitmap().RawString() - idToStore := newRandomBitmap().RawString() + blobHashToStore := newRandomBitmap().RawString() - data, err := bencode.EncodeBytes(map[string]interface{}{ - headerTypeField: requestType, - headerMessageIDField: messageID, - headerNodeIDField: testNodeID.RawString(), - headerPayloadField: "store", - headerArgsField: []string{idToStore}, - }) + storeRequest := Request{ + ID: messageID, + NodeID: testNodeID.RawString(), + Method: storeMethod, + StoreArgs: &storeArgs{ + BlobHash: blobHashToStore, + }, + } + storeRequest.StoreArgs.Value.Token = "arst" + storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString() + storeRequest.StoreArgs.Value.Port = 9999 + + _ = "64 " + // start message + "313A30 693065" + // type: 0 + "313A31 3230 3A 6EB490B5788B63F0F7E6D92352024D0CBDEC2D3A" + // message id + "313A32 3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // node id + "313A33 35 3A 73746F7265" + // method + "313A34 6C" + // start args list + "3438 3A 3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468" + // block hash + "64" + // start value dict + "363A6C6272796964 3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // lbry id + "343A706F7274 69 33333333 65" + // port + "353A746F6B656E 3438 3A 17C2D8E1E48EF21567FE4AD5C8ED944B798D3B65AB58D0C9122AD6587D1B5FED472EA2CB12284CEFA1C21EFF302322BD" + // token + "65" + // end value dict + "3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // node id + "693065" + // self store (integer) + "65" + // end args list + "65" // end message + + data, err := bencode.EncodeBytes(storeRequest) if err != nil { - panic(err) + t.Error(err) + return } conn.toRead <- testUDPPacket{addr: conn.addr, data: data} @@ -191,3 +219,63 @@ func TestStore(t *testing.T) { } } } + +func TestFindNode(t *testing.T) { + dhtNodeID := newRandomBitmap() + + conn := newTestUDPConn("127.0.0.1:21217") + + dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht.conn = conn + dht.listen() + go dht.runHandler() + + data, _ := hex.DecodeString("64313a30693065313a3132303a2afdf2272981651a2c64e39ab7f04ec2d3b5d5d2313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33383a66696e644e6f6465313a346c34383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b6565") + + conn.toRead <- testUDPPacket{addr: conn.addr, data: data} + timer := time.NewTimer(3 * time.Second) + + select { + case <-timer.C: + t.Error("timeout") + case resp := <-conn.writes: + var response map[string]interface{} + err := bencode.DecodeBytes(resp.data, &response) + if err != nil { + t.Error(err) + return + } + + spew.Dump(response) + } +} + +func TestFindValue(t *testing.T) { + dhtNodeID := newRandomBitmap() + + conn := newTestUDPConn("127.0.0.1:21217") + + dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()}) + dht.conn = conn + dht.listen() + go dht.runHandler() + + data, _ := hex.DecodeString("6469306569306569316532303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f69326534383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b693365393a66696e6456616c75656934656c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565") + + conn.toRead <- testUDPPacket{addr: conn.addr, data: data} + timer := time.NewTimer(3 * time.Second) + + select { + case <-timer.C: + t.Error("timeout") + case resp := <-conn.writes: + var response map[string]interface{} + err := bencode.DecodeBytes(resp.data, &response) + if err != nil { + t.Error(err) + return + } + + spew.Dump(response) + } +} diff --git a/dht/message.go b/dht/message.go new file mode 100644 index 0000000..120872e --- /dev/null +++ b/dht/message.go @@ -0,0 +1,267 @@ +package dht + +import ( + "github.com/lbryio/errors.go" + + "github.com/spf13/cast" + "github.com/zeebo/bencode" +) + +const ( + pingMethod = "ping" + storeMethod = "store" + findNodeMethod = "findNode" + findValueMethod = "findValue" +) + +const ( + pingSuccessResponse = "pong" + storeSuccessResponse = "OK" +) + +const ( + requestType = 0 + responseType = 1 + errorType = 2 +) + +const ( + // these are strings because bencode requires bytestring keys + headerTypeField = "0" + headerMessageIDField = "1" // message id is 20 bytes long + headerNodeIDField = "2" // node id is 48 bytes long + headerPayloadField = "3" + headerArgsField = "4" +) + +type Message interface { + bencode.Marshaler + GetID() string +} + +type Request struct { + ID string + NodeID string + Method string + Args []string + StoreArgs *storeArgs +} + +func (r Request) GetID() string { return r.ID } +func (r Request) MarshalBencode() ([]byte, error) { + var args interface{} + if r.StoreArgs != nil { + args = r.StoreArgs + } else { + args = r.Args + } + return bencode.EncodeBytes(map[string]interface{}{ + headerTypeField: requestType, + headerMessageIDField: r.ID, + headerNodeIDField: r.NodeID, + headerPayloadField: r.Method, + headerArgsField: args, + }) +} + +func (r *Request) UnmarshalBencode(b []byte) error { + var raw struct { + ID string `bencode:"1"` + NodeID string `bencode:"2"` + Method string `bencode:"3"` + Args bencode.RawMessage `bencode:"4"` + } + err := bencode.DecodeBytes(b, &raw) + if err != nil { + return err + } + + r.ID = raw.ID + r.NodeID = raw.NodeID + r.Method = raw.Method + + if r.Method == storeMethod { + err = bencode.DecodeBytes(raw.Args, &r.StoreArgs) + } else { + err = bencode.DecodeBytes(raw.Args, &r.Args) + } + if err != nil { + return err + } + + return nil +} + +type storeArgs struct { + BlobHash string // 48 bytes + Value struct { + Token string `bencode:"token"` + LbryID string `bencode:"lbryid"` + Port int `bencode:"port"` + } + NodeID string // 48 bytes + SelfStore bool // this is an int on the wire +} + +func (s *storeArgs) MarshalBencode() ([]byte, error) { + encodedValue, err := bencode.EncodeString(s.Value) + if err != nil { + return nil, err + } + + selfStoreStr := 0 + if s.SelfStore { + selfStoreStr = 1 + } + + return bencode.EncodeBytes([]interface{}{ + s.BlobHash, + bencode.RawMessage(encodedValue), + s.NodeID, + selfStoreStr, + }) +} + +func (s *storeArgs) UnmarshalBencode(b []byte) error { + var argsInt []bencode.RawMessage + err := bencode.DecodeBytes(b, &argsInt) + if err != nil { + return err + } + + if len(argsInt) != 4 { + return errors.Err("unexpected number of fields for store args. got " + cast.ToString(len(argsInt))) + } + + err = bencode.DecodeBytes(argsInt[0], &s.BlobHash) + if err != nil { + return errors.Err(err) + } + + err = bencode.DecodeBytes(argsInt[1], &s.Value) + if err != nil { + return errors.Err(err) + } + + err = bencode.DecodeBytes(argsInt[2], &s.NodeID) + if err != nil { + return errors.Err(err) + } + + var selfStore int + err = bencode.DecodeBytes(argsInt[3], &selfStore) + if err != nil { + return errors.Err(err) + } + if selfStore == 0 { + s.SelfStore = false + } else if selfStore == 1 { + s.SelfStore = true + } else { + return errors.Err("selfstore must be 1 or 0") + } + + return nil +} + +type findNodeDatum struct { + ID bitmap + IP string + Port int +} + +func (f *findNodeDatum) UnmarshalBencode(b []byte) error { + var contact []bencode.RawMessage + err := bencode.DecodeBytes(b, &contact) + if err != nil { + return err + } + + if len(contact) != 3 { + return errors.Err("invalid-sized contact") + } + + err = bencode.DecodeBytes(contact[0], &f.ID) + if err != nil { + return err + } + err = bencode.DecodeBytes(contact[1], &f.IP) + if err != nil { + return err + } + err = bencode.DecodeBytes(contact[2], &f.Port) + if err != nil { + return err + } + + return nil +} + +type Response struct { + ID string + NodeID string + Data string + FindNodeData []findNodeDatum +} + +func (r Response) GetID() string { return r.ID } + +func (r Response) MarshalBencode() ([]byte, error) { + data := map[string]interface{}{ + headerTypeField: responseType, + headerMessageIDField: r.ID, + headerNodeIDField: r.NodeID, + } + if r.Data != "" { + data[headerPayloadField] = r.Data + } else { + var nodes []interface{} + for _, n := range r.FindNodeData { + nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port}) + } + data[headerPayloadField] = nodes + } + + return bencode.EncodeBytes(data) +} + +func (r *Response) UnmarshalBencode(b []byte) error { + var raw struct { + ID string `bencode:"1"` + NodeID string `bencode:"2"` + Data bencode.RawMessage `bencode:"2"` + } + err := bencode.DecodeBytes(b, &raw) + if err != nil { + return err + } + + r.ID = raw.ID + r.NodeID = raw.NodeID + + err = bencode.DecodeBytes(raw.Data, &r.Data) + if err != nil { + err = bencode.DecodeBytes(raw.Data, r.FindNodeData) + if err != nil { + return err + } + } +} + +type Error struct { + ID string + NodeID string + Response []string + ExceptionType string +} + +func (e Error) GetID() string { return e.ID } +func (e Error) MarshalBencode() ([]byte, error) { + return bencode.EncodeBytes(map[string]interface{}{ + headerTypeField: errorType, + headerMessageIDField: e.ID, + headerNodeIDField: e.NodeID, + headerPayloadField: e.ExceptionType, + headerArgsField: e.Response, + }) +} diff --git a/dht/message_test.go b/dht/message_test.go new file mode 100644 index 0000000..f2ec466 --- /dev/null +++ b/dht/message_test.go @@ -0,0 +1,75 @@ +package dht + +import ( + "encoding/hex" + "reflect" + "strconv" + "strings" + "testing" + + log "github.com/sirupsen/logrus" + "github.com/zeebo/bencode" +) + +func TestBencodeDecodeStoreArgs(t *testing.T) { + log.SetLevel(log.DebugLevel) + + blobHash := "3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468" + lbryID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + port := hex.EncodeToString([]byte("3333")) + token := "17C2D8E1E48EF21567FE4AD5C8ED944B798D3B65AB58D0C9122AD6587D1B5FED472EA2CB12284CEFA1C21EFF302322BD" + nodeID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + selfStore := hex.EncodeToString([]byte("1")) + + raw := "6C" + // start args list + "3438 3A " + blobHash + // blob hash + "64" + // start value dict + "363A6C6272796964 3438 3A " + lbryID + // lbry id + "343A706F7274 69 " + port + " 65" + // port + "353A746F6B656E 3438 3A " + token + // token + "65" + // end value dict + "3438 3A " + nodeID + // node id + "69 " + selfStore + " 65" + // self store (integer) + "65" // end args list + + raw = strings.ToLower(strings.Replace(raw, " ", "", -1)) + + data, err := hex.DecodeString(raw) + if err != nil { + t.Error(err) + return + } + + storeArgs := &storeArgs{} + err = bencode.DecodeBytes(data, storeArgs) + if err != nil { + t.Error(err) + } + + if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) { + t.Error("blob hash mismatch") + } + if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) { + t.Error("lbryid mismatch") + } + if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port { + t.Error("port mismatch") + } + if hex.EncodeToString([]byte(storeArgs.Value.Token)) != strings.ToLower(token) { + t.Error("token mismatch") + } + if hex.EncodeToString([]byte(storeArgs.NodeID)) != strings.ToLower(nodeID) { + t.Error("node id mismatch") + } + if !storeArgs.SelfStore { + t.Error("selfStore mismatch") + } + + reencoded, err := bencode.EncodeBytes(storeArgs) + if err != nil { + t.Error(err) + } else if !reflect.DeepEqual(reencoded, data) { + t.Error("reencoded data does not match original") + //spew.Dump(reencoded, data) + } +} diff --git a/dht/messages.go b/dht/messages.go deleted file mode 100644 index 615d8b0..0000000 --- a/dht/messages.go +++ /dev/null @@ -1,103 +0,0 @@ -package dht - -import "github.com/zeebo/bencode" - -const ( - pingMethod = "ping" - storeMethod = "store" - findNodeMethod = "findNode" - findValueMethod = "findValue" -) - -const ( - pingSuccessResponse = "pong" - storeSuccessResponse = "OK" -) - -const ( - requestType = 0 - responseType = 1 - errorType = 2 -) - -const ( - // these are strings because bencode requires bytestring keys - headerTypeField = "0" - headerMessageIDField = "1" - headerNodeIDField = "2" - headerPayloadField = "3" - headerArgsField = "4" -) - -type Message interface { - GetID() string - Encode() ([]byte, error) -} - -type Request struct { - ID string - NodeID string - Method string - Args []string -} - -func (r Request) GetID() string { return r.ID } -func (r Request) Encode() ([]byte, error) { - return bencode.EncodeBytes(map[string]interface{}{ - headerTypeField: requestType, - headerMessageIDField: r.ID, - headerNodeIDField: r.NodeID, - headerPayloadField: r.Method, - headerArgsField: r.Args, - }) -} - -type findNodeDatum struct { - ID string - IP string - Port int -} -type Response struct { - ID string - NodeID string - Data string - FindNodeData []findNodeDatum -} - -func (r Response) GetID() string { return r.ID } -func (r Response) Encode() ([]byte, error) { - data := map[string]interface{}{ - headerTypeField: responseType, - headerMessageIDField: r.ID, - headerNodeIDField: r.NodeID, - } - if r.Data != "" { - data[headerPayloadField] = r.Data - } else { - var nodes []interface{} - for _, n := range r.FindNodeData { - nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port}) - } - data[headerPayloadField] = nodes - } - - return bencode.EncodeBytes(data) -} - -type Error struct { - ID string - NodeID string - Response []string - ExceptionType string -} - -func (e Error) GetID() string { return e.ID } -func (e Error) Encode() ([]byte, error) { - return bencode.EncodeBytes(map[string]interface{}{ - headerTypeField: errorType, - headerMessageIDField: e.ID, - headerNodeIDField: e.NodeID, - headerPayloadField: e.ExceptionType, - headerArgsField: e.Response, - }) -} diff --git a/dht/store.go b/dht/store.go new file mode 100644 index 0000000..c7fbeea --- /dev/null +++ b/dht/store.go @@ -0,0 +1,48 @@ +package dht + +import ( + "sync" + "time" +) + +type peer struct { + node *Node + lastPublished time.Time + originallyPublished time.Time + originalPublisherID bitmap +} + +type peerStore struct { + data map[bitmap][]peer + lock sync.RWMutex +} + +func newPeerStore() *peerStore { + return &peerStore{ + data: map[bitmap][]peer{}, + } +} + +func (s *peerStore) Insert(key bitmap, node *Node, lastPublished, originallyPublished time.Time, originaPublisherID bitmap) { + s.lock.Lock() + defer s.lock.Unlock() + newPeer := peer{node: node, lastPublished: lastPublished, originallyPublished: originallyPublished, originalPublisherID: originaPublisherID} + _, ok := s.data[key] + if !ok { + s.data[key] = []peer{newPeer} + } else { + s.data[key] = append(s.data[key], newPeer) + } +} + +func (s *peerStore) GetNodes(key bitmap) []*Node { + s.lock.RLock() + defer s.lock.RUnlock() + nodes := []*Node{} + if peers, ok := s.data[key]; ok { + for _, p := range peers { + nodes = append(nodes, p.node) + } + } + return nodes +}