diff --git a/dht/bitmap.go b/dht/bitmap.go index e8d9caf..54dc267 100644 --- a/dht/bitmap.go +++ b/dht/bitmap.go @@ -5,21 +5,22 @@ import ( "encoding/hex" "strconv" + "github.com/lbryio/errors.go" "github.com/lyoshenka/bencode" ) type bitmap [nodeIDLength]byte func (b bitmap) RawString() string { - return string(b[0:nodeIDLength]) + return string(b[:]) } func (b bitmap) Hex() string { - return hex.EncodeToString(b[0:nodeIDLength]) + return hex.EncodeToString(b[:]) } func (b bitmap) HexShort() string { - return hex.EncodeToString(b[0:nodeIDLength])[:8] + return hex.EncodeToString(b[:4]) } func (b bitmap) Equals(other bitmap) bool { @@ -66,6 +67,9 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error { if err != nil { return err } + if len(str) != nodeIDLength { + return errors.Err("invalid node ID length") + } copy(b[:], str) return nil } diff --git a/dht/dht.go b/dht/dht.go index bd56323..63cc1ad 100644 --- a/dht/dht.go +++ b/dht/dht.go @@ -203,7 +203,7 @@ func (dht *DHT) join() { } // now call iterativeFind on yourself - _, err := dht.FindNodes(dht.node.id) + _, _, err := dht.Get(dht.node.id) if err != nil { log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error()) } @@ -260,16 +260,7 @@ func printState(dht *DHT) { } } -func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) { - nf := newNodeFinder(dht, hash, false) - res, err := nf.Find() - if err != nil { - return nil, err - } - return res.Nodes, nil -} - -func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) { +func (dht *DHT) Get(hash bitmap) ([]Node, bool, error) { nf := newNodeFinder(dht, hash, true) res, err := nf.Find() if err != nil { @@ -278,6 +269,30 @@ func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) { return res.Nodes, res.Found, nil } +func (dht *DHT) Put(hash bitmap) error { + nf := newNodeFinder(dht, hash, false) + res, err := nf.Find() + if err != nil { + return err + } + + for _, node := range res.Nodes { + send(dht, node.Addr(), &Request{ + Method: storeMethod, + StoreArgs: &storeArgs{ + BlobHash: hash.RawString(), + Value: storeArgsValue{ + Token: "", + LbryID: dht.node.id, + Port: dht.node.port, + }, + }, + }) + } + + return nil +} + type nodeFinder struct { findValue bool // true if we're using findValue target bitmap diff --git a/dht/dht_test.go b/dht/dht_test.go index 179da59..fd4c76c 100644 --- a/dht/dht_test.go +++ b/dht/dht_test.go @@ -4,8 +4,6 @@ import ( "net" "testing" "time" - - "github.com/davecgh/go-spew/spew" ) func TestDHT_FindNodes(t *testing.T) { @@ -22,7 +20,7 @@ func TestDHT_FindNodes(t *testing.T) { go dht1.Start() defer dht1.Shutdown() - time.Sleep(1 * time.Second) + time.Sleep(1 * time.Second) // give dhts a chance to connect dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}}) if err != nil { @@ -42,13 +40,15 @@ func TestDHT_FindNodes(t *testing.T) { time.Sleep(1 * time.Second) // give dhts a chance to connect - foundNodes, err := dht3.FindNodes(id2) + foundNodes, found, err := dht3.Get(newRandomBitmap()) if err != nil { t.Fatal(err) } - spew.Dump(foundNodes) + if found { + t.Fatal("something was found, but it should not have been") + } if len(foundNodes) != 2 { t.Errorf("expected 2 nodes, found %d", len(foundNodes)) @@ -74,7 +74,7 @@ func TestDHT_FindNodes(t *testing.T) { } } -func TestDHT_FindValue(t *testing.T) { +func TestDHT_Get(t *testing.T) { id1 := newRandomBitmap() id2 := newRandomBitmap() id3 := newRandomBitmap() @@ -111,7 +111,7 @@ func TestDHT_FindValue(t *testing.T) { nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678} dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind) - foundNodes, found, err := dht3.FindValue(nodeToFind.id) + foundNodes, found, err := dht3.Get(nodeToFind.id) if err != nil { t.Fatal(err) } diff --git a/dht/message.go b/dht/message.go index cdc1dea..77ef27a 100644 --- a/dht/message.go +++ b/dht/message.go @@ -1,7 +1,9 @@ package dht import ( + "crypto/rand" "encoding/hex" + "reflect" "strings" "github.com/lbryio/errors.go" @@ -41,9 +43,39 @@ type Message interface { bencode.Marshaler } +type messageID [messageIDLength]byte + +func (m messageID) HexShort() string { + return hex.EncodeToString(m[:])[:8] +} + +func (m *messageID) UnmarshalBencode(encoded []byte) error { + var str string + err := bencode.DecodeBytes(encoded, &str) + if err != nil { + return err + } + copy(m[:], str) + return nil +} + +func (m messageID) MarshalBencode() ([]byte, error) { + str := string(m[:]) + return bencode.EncodeBytes(str) +} + +func newMessageID() messageID { + var m messageID + _, err := rand.Read(m[:]) + if err != nil { + panic(err) + } + return m +} + type Request struct { - ID string - NodeID string + ID messageID + NodeID bitmap Method string Args []string StoreArgs *storeArgs @@ -67,8 +99,8 @@ func (r Request) MarshalBencode() ([]byte, error) { func (r *Request) UnmarshalBencode(b []byte) error { var raw struct { - ID string `bencode:"1"` - NodeID string `bencode:"2"` + ID messageID `bencode:"1"` + NodeID bitmap `bencode:"2"` Method string `bencode:"3"` Args bencode.RawMessage `bencode:"4"` } @@ -94,13 +126,15 @@ func (r *Request) UnmarshalBencode(b []byte) error { return nil } +type storeArgsValue struct { + Token string `bencode:"token"` + LbryID bitmap `bencode:"lbryid"` + Port int `bencode:"port"` +} + type storeArgs struct { - BlobHash string - Value struct { - Token string `bencode:"token"` - LbryID string `bencode:"lbryid"` - Port int `bencode:"port"` - } + BlobHash string + Value storeArgsValue NodeID bitmap SelfStore bool // this is an int on the wire } @@ -167,8 +201,8 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error { } type Response struct { - ID string - NodeID string + ID messageID + NodeID bitmap Data string FindNodeData []Node FindValueKey string @@ -219,8 +253,8 @@ func (r Response) MarshalBencode() ([]byte, error) { func (r *Response) UnmarshalBencode(b []byte) error { var raw struct { - ID string `bencode:"1"` - NodeID string `bencode:"2"` + ID messageID `bencode:"1"` + NodeID bitmap `bencode:"2"` Data bencode.RawMessage `bencode:"3"` } err := bencode.DecodeBytes(b, &raw) @@ -269,10 +303,10 @@ func (r *Response) UnmarshalBencode(b []byte) error { } type Error struct { - ID string - NodeID string - Response []string + ID messageID + NodeID bitmap ExceptionType string + Response []string } func (e Error) MarshalBencode() ([]byte, error) { @@ -284,3 +318,29 @@ func (e Error) MarshalBencode() ([]byte, error) { headerArgsField: e.Response, }) } + +func (e *Error) UnmarshalBencode(b []byte) error { + var raw struct { + ID messageID `bencode:"1"` + NodeID bitmap `bencode:"2"` + ExceptionType string `bencode:"3"` + Args interface{} `bencode:"4"` + } + err := bencode.DecodeBytes(b, &raw) + if err != nil { + return err + } + + e.ID = raw.ID + e.NodeID = raw.NodeID + e.ExceptionType = raw.ExceptionType + + if reflect.TypeOf(raw.Args).Kind() == reflect.Slice { + v := reflect.ValueOf(raw.Args) + for i := 0; i < v.Len(); i++ { + e.Response = append(e.Response, cast.ToString(v.Index(i).Interface())) + } + } + + return nil +} diff --git a/dht/message_test.go b/dht/message_test.go index de31e6e..55ebf2a 100644 --- a/dht/message_test.go +++ b/dht/message_test.go @@ -48,7 +48,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) { if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) { t.Error("blob hash mismatch") } - if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) { + if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) { t.Error("lbryid mismatch") } if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port { @@ -76,7 +76,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) { func TestBencodeFindNodesResponse(t *testing.T) { res := Response{ ID: newMessageID(), - NodeID: newRandomBitmap().RawString(), + NodeID: newRandomBitmap(), FindNodeData: []Node{ {id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, {id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765}, @@ -100,7 +100,7 @@ func TestBencodeFindNodesResponse(t *testing.T) { func TestBencodeFindValueResponse(t *testing.T) { res := Response{ ID: newMessageID(), - NodeID: newRandomBitmap().RawString(), + NodeID: newRandomBitmap(), FindValueKey: newRandomBitmap().RawString(), FindNodeData: []Node{ {id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, @@ -125,8 +125,8 @@ func compareResponses(t *testing.T, res, res2 Response) { if res.ID != res2.ID { t.Errorf("expected ID %s, got %s", res.ID, res2.ID) } - if res.NodeID != res2.NodeID { - t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID) + if !res.NodeID.Equals(res2.NodeID) { + t.Errorf("expected NodeID %s, got %s", res.NodeID.Hex(), res2.NodeID.Hex()) } if res.Data != res2.Data { t.Errorf("expected Data %s, got %s", res.Data, res2.Data) diff --git a/dht/rpc.go b/dht/rpc.go index 286f60b..fb8df4c 100644 --- a/dht/rpc.go +++ b/dht/rpc.go @@ -1,28 +1,16 @@ package dht import ( - "crypto/rand" "encoding/hex" "net" - "reflect" "strings" "time" "github.com/davecgh/go-spew/spew" "github.com/lyoshenka/bencode" log "github.com/sirupsen/logrus" - "github.com/spf13/cast" ) -func newMessageID() string { - buf := make([]byte, messageIDLength) - _, err := rand.Read(buf) - if err != nil { - panic(err) - } - return string(buf) -} - // handlePacket handles packets received from udp. func handlePacket(dht *DHT, pkt packet) { //log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data)) @@ -48,7 +36,7 @@ func handlePacket(dht *DHT, pkt packet) { log.Errorln(err) return } - log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) + log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args)) handleRequest(dht, pkt.raddr, request) case responseType: @@ -58,17 +46,17 @@ func handlePacket(dht *DHT, pkt packet) { log.Errorln(err) return } - log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug()) + log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug()) handleResponse(dht, pkt.raddr, response) case errorType: - e := Error{ - ID: data[headerMessageIDField].(string), - NodeID: data[headerNodeIDField].(string), - ExceptionType: data[headerPayloadField].(string), - Response: getArgs(data[headerArgsField]), + e := Error{} + err = bencode.DecodeBytes(pkt.data, &e) + if err != nil { + log.Errorln(err) + return } - log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) + log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType) handleError(dht, pkt.raddr, e) default: @@ -79,14 +67,14 @@ func handlePacket(dht *DHT, pkt packet) { // handleRequest handles the requests received from udp. func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { - if request.NodeID == dht.node.id.RawString() { + if request.NodeID.Equals(dht.node.id) { log.Warn("ignoring self-request") return } switch request.Method { case pingMethod: - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse}) + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse}) case storeMethod: if request.StoreArgs.BlobHash == "" { log.Errorln("blobhash is empty") @@ -95,7 +83,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { // TODO: we should be sending the IP in the request, not just using the sender's IP // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) - send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse}) + send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse}) case findNodeMethod: if len(request.Args) < 1 { log.Errorln("nothing to find") @@ -117,7 +105,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { } if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 { - response := Response{ID: request.ID, NodeID: dht.node.id.RawString()} + response := Response{ID: request.ID, NodeID: dht.node.id} response.FindValueKey = request.Args[0] response.FindNodeData = nodes send(dht, addr, response) @@ -131,7 +119,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { return } - node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} + node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port} dht.rt.Update(node) } @@ -139,7 +127,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) { nodeID := newBitmapFromString(request.Args[0]) closestNodes := dht.rt.GetClosest(nodeID, bucketSize) if len(closestNodes) > 0 { - response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} + response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))} for i, n := range closestNodes { response.FindNodeData[i] = n } @@ -156,14 +144,14 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) { tx.res <- &response } - node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} + node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port} dht.rt.Update(node) } // handleError handles errors received from udp. func handleError(dht *DHT, addr *net.UDPAddr, e Error) { spew.Dump(e) - node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} + node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port} dht.rt.Update(node) } @@ -176,10 +164,10 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error { if req, ok := data.(Request); ok { log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)", - dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args)) + dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args)) } else if res, ok := data.(Response); ok { log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s", - dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug()) + dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug()) } else { log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data)) } @@ -190,17 +178,6 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error { return err } -func getArgs(argsInt interface{}) []string { - var args []string - if reflect.TypeOf(argsInt).Kind() == reflect.Slice { - v := reflect.ValueOf(argsInt) - for i := 0; i < v.Len(); i++ { - args = append(args, cast.ToString(v.Index(i).Interface())) - } - } - return args -} - func argsToString(args []string) string { argsCopy := make([]string, len(args)) copy(argsCopy, args) diff --git a/dht/rpc_test.go b/dht/rpc_test.go index 70df70a..54414c3 100644 --- a/dht/rpc_test.go +++ b/dht/rpc_test.go @@ -151,7 +151,7 @@ func TestPing(t *testing.T) { rMessageID, ok := response[headerMessageIDField].(string) if !ok { t.Error("message ID is not a string") - } else if rMessageID != messageID { + } else if rMessageID != string(messageID[:]) { t.Error("unexpected message ID") } } @@ -203,16 +203,18 @@ func TestStore(t *testing.T) { storeRequest := Request{ ID: messageID, - NodeID: testNodeID.RawString(), + NodeID: testNodeID, Method: storeMethod, StoreArgs: &storeArgs{ BlobHash: blobHashToStore, - NodeID: testNodeID, + Value: storeArgsValue{ + Token: "arst", + LbryID: testNodeID, + Port: 9999, + }, + NodeID: testNodeID, }, } - storeRequest.StoreArgs.Value.Token = "arst" - storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString() - storeRequest.StoreArgs.Value.Port = 9999 _ = "64 " + // start message "313A30 693065" + // type: 0 @@ -305,7 +307,7 @@ func TestFindNode(t *testing.T) { request := Request{ ID: messageID, - NodeID: testNodeID.RawString(), + NodeID: testNodeID, Method: findNodeMethod, Args: []string{blobHashToFind}, } @@ -390,7 +392,7 @@ func TestFindValueExisting(t *testing.T) { request := Request{ ID: messageID, - NodeID: testNodeID.RawString(), + NodeID: testNodeID, Method: findValueMethod, Args: []string{valueToFind}, } @@ -468,7 +470,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) { request := Request{ ID: messageID, - NodeID: testNodeID.RawString(), + NodeID: testNodeID, Method: findValueMethod, Args: []string{valueToFind}, } @@ -517,7 +519,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) { verifyContacts(t, contacts, nodes) } -func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) { +func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) { if len(resp) != 4 { t.Errorf("expected 4 response fields, got %d", len(resp)) } @@ -541,7 +543,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod rMessageID, ok := resp[headerMessageIDField].(string) if !ok { t.Error("message ID is not a string") - } else if rMessageID != messageID { + } else if rMessageID != string(id[:]) { t.Error("unexpected message ID") } if len(rMessageID) != messageIDLength { diff --git a/dht/transaction_manager.go b/dht/transaction_manager.go index 061b2ae..e8799aa 100644 --- a/dht/transaction_manager.go +++ b/dht/transaction_manager.go @@ -19,7 +19,7 @@ type transaction struct { // transactionManager represents the manager of transactions. type transactionManager struct { lock *sync.RWMutex - transactions map[string]*transaction + transactions map[messageID]*transaction dht *DHT } @@ -27,7 +27,7 @@ type transactionManager struct { func newTransactionManager(dht *DHT) *transactionManager { return &transactionManager{ lock: &sync.RWMutex{}, - transactions: make(map[string]*transaction), + transactions: make(map[messageID]*transaction), dht: dht, } } @@ -40,14 +40,14 @@ func (tm *transactionManager) insert(trans *transaction) { } // delete removes a transaction from transactionManager. -func (tm *transactionManager) delete(transID string) { +func (tm *transactionManager) delete(id messageID) { tm.lock.Lock() defer tm.lock.Unlock() - delete(tm.transactions, transID) + delete(tm.transactions, id) } // find transaction for id. optionally ensure that addr matches node from transaction -func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction { +func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction { tm.lock.RLock() defer tm.lock.RUnlock() @@ -73,7 +73,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req defer close(ch) req.ID = newMessageID() - req.NodeID = tm.dht.node.id.RawString() + req.NodeID = tm.dht.node.id trans := &transaction{ node: node, req: req,