proper types for some IDs

This commit is contained in:
Alex Grintsvayg 2018-04-03 13:38:01 -04:00
parent 13f991852b
commit 868e243afc
8 changed files with 159 additions and 101 deletions

View file

@ -5,21 +5,22 @@ import (
"encoding/hex" "encoding/hex"
"strconv" "strconv"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
) )
type bitmap [nodeIDLength]byte type bitmap [nodeIDLength]byte
func (b bitmap) RawString() string { func (b bitmap) RawString() string {
return string(b[0:nodeIDLength]) return string(b[:])
} }
func (b bitmap) Hex() string { func (b bitmap) Hex() string {
return hex.EncodeToString(b[0:nodeIDLength]) return hex.EncodeToString(b[:])
} }
func (b bitmap) HexShort() string { func (b bitmap) HexShort() string {
return hex.EncodeToString(b[0:nodeIDLength])[:8] return hex.EncodeToString(b[:4])
} }
func (b bitmap) Equals(other bitmap) bool { func (b bitmap) Equals(other bitmap) bool {
@ -66,6 +67,9 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error {
if err != nil { if err != nil {
return err return err
} }
if len(str) != nodeIDLength {
return errors.Err("invalid node ID length")
}
copy(b[:], str) copy(b[:], str)
return nil return nil
} }

View file

@ -203,7 +203,7 @@ func (dht *DHT) join() {
} }
// now call iterativeFind on yourself // now call iterativeFind on yourself
_, err := dht.FindNodes(dht.node.id) _, _, err := dht.Get(dht.node.id)
if err != nil { if err != nil {
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error()) log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
} }
@ -260,16 +260,7 @@ func printState(dht *DHT) {
} }
} }
func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) { func (dht *DHT) Get(hash bitmap) ([]Node, bool, error) {
nf := newNodeFinder(dht, hash, false)
res, err := nf.Find()
if err != nil {
return nil, err
}
return res.Nodes, nil
}
func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
nf := newNodeFinder(dht, hash, true) nf := newNodeFinder(dht, hash, true)
res, err := nf.Find() res, err := nf.Find()
if err != nil { if err != nil {
@ -278,6 +269,30 @@ func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
return res.Nodes, res.Found, nil return res.Nodes, res.Found, nil
} }
func (dht *DHT) Put(hash bitmap) error {
nf := newNodeFinder(dht, hash, false)
res, err := nf.Find()
if err != nil {
return err
}
for _, node := range res.Nodes {
send(dht, node.Addr(), &Request{
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: hash.RawString(),
Value: storeArgsValue{
Token: "",
LbryID: dht.node.id,
Port: dht.node.port,
},
},
})
}
return nil
}
type nodeFinder struct { type nodeFinder struct {
findValue bool // true if we're using findValue findValue bool // true if we're using findValue
target bitmap target bitmap

View file

@ -4,8 +4,6 @@ import (
"net" "net"
"testing" "testing"
"time" "time"
"github.com/davecgh/go-spew/spew"
) )
func TestDHT_FindNodes(t *testing.T) { func TestDHT_FindNodes(t *testing.T) {
@ -22,7 +20,7 @@ func TestDHT_FindNodes(t *testing.T) {
go dht1.Start() go dht1.Start()
defer dht1.Shutdown() defer dht1.Shutdown()
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second) // give dhts a chance to connect
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}}) dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil { if err != nil {
@ -42,13 +40,15 @@ func TestDHT_FindNodes(t *testing.T) {
time.Sleep(1 * time.Second) // give dhts a chance to connect time.Sleep(1 * time.Second) // give dhts a chance to connect
foundNodes, err := dht3.FindNodes(id2) foundNodes, found, err := dht3.Get(newRandomBitmap())
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
spew.Dump(foundNodes) if found {
t.Fatal("something was found, but it should not have been")
}
if len(foundNodes) != 2 { if len(foundNodes) != 2 {
t.Errorf("expected 2 nodes, found %d", len(foundNodes)) t.Errorf("expected 2 nodes, found %d", len(foundNodes))
@ -74,7 +74,7 @@ func TestDHT_FindNodes(t *testing.T) {
} }
} }
func TestDHT_FindValue(t *testing.T) { func TestDHT_Get(t *testing.T) {
id1 := newRandomBitmap() id1 := newRandomBitmap()
id2 := newRandomBitmap() id2 := newRandomBitmap()
id3 := newRandomBitmap() id3 := newRandomBitmap()
@ -111,7 +111,7 @@ func TestDHT_FindValue(t *testing.T) {
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678} nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind) dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
foundNodes, found, err := dht3.FindValue(nodeToFind.id) foundNodes, found, err := dht3.Get(nodeToFind.id)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View file

@ -1,7 +1,9 @@
package dht package dht
import ( import (
"crypto/rand"
"encoding/hex" "encoding/hex"
"reflect"
"strings" "strings"
"github.com/lbryio/errors.go" "github.com/lbryio/errors.go"
@ -41,9 +43,39 @@ type Message interface {
bencode.Marshaler bencode.Marshaler
} }
type messageID [messageIDLength]byte
func (m messageID) HexShort() string {
return hex.EncodeToString(m[:])[:8]
}
func (m *messageID) UnmarshalBencode(encoded []byte) error {
var str string
err := bencode.DecodeBytes(encoded, &str)
if err != nil {
return err
}
copy(m[:], str)
return nil
}
func (m messageID) MarshalBencode() ([]byte, error) {
str := string(m[:])
return bencode.EncodeBytes(str)
}
func newMessageID() messageID {
var m messageID
_, err := rand.Read(m[:])
if err != nil {
panic(err)
}
return m
}
type Request struct { type Request struct {
ID string ID messageID
NodeID string NodeID bitmap
Method string Method string
Args []string Args []string
StoreArgs *storeArgs StoreArgs *storeArgs
@ -67,8 +99,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
func (r *Request) UnmarshalBencode(b []byte) error { func (r *Request) UnmarshalBencode(b []byte) error {
var raw struct { var raw struct {
ID string `bencode:"1"` ID messageID `bencode:"1"`
NodeID string `bencode:"2"` NodeID bitmap `bencode:"2"`
Method string `bencode:"3"` Method string `bencode:"3"`
Args bencode.RawMessage `bencode:"4"` Args bencode.RawMessage `bencode:"4"`
} }
@ -94,13 +126,15 @@ func (r *Request) UnmarshalBencode(b []byte) error {
return nil return nil
} }
type storeArgsValue struct {
Token string `bencode:"token"`
LbryID bitmap `bencode:"lbryid"`
Port int `bencode:"port"`
}
type storeArgs struct { type storeArgs struct {
BlobHash string BlobHash string
Value struct { Value storeArgsValue
Token string `bencode:"token"`
LbryID string `bencode:"lbryid"`
Port int `bencode:"port"`
}
NodeID bitmap NodeID bitmap
SelfStore bool // this is an int on the wire SelfStore bool // this is an int on the wire
} }
@ -167,8 +201,8 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
} }
type Response struct { type Response struct {
ID string ID messageID
NodeID string NodeID bitmap
Data string Data string
FindNodeData []Node FindNodeData []Node
FindValueKey string FindValueKey string
@ -219,8 +253,8 @@ func (r Response) MarshalBencode() ([]byte, error) {
func (r *Response) UnmarshalBencode(b []byte) error { func (r *Response) UnmarshalBencode(b []byte) error {
var raw struct { var raw struct {
ID string `bencode:"1"` ID messageID `bencode:"1"`
NodeID string `bencode:"2"` NodeID bitmap `bencode:"2"`
Data bencode.RawMessage `bencode:"3"` Data bencode.RawMessage `bencode:"3"`
} }
err := bencode.DecodeBytes(b, &raw) err := bencode.DecodeBytes(b, &raw)
@ -269,10 +303,10 @@ func (r *Response) UnmarshalBencode(b []byte) error {
} }
type Error struct { type Error struct {
ID string ID messageID
NodeID string NodeID bitmap
Response []string
ExceptionType string ExceptionType string
Response []string
} }
func (e Error) MarshalBencode() ([]byte, error) { func (e Error) MarshalBencode() ([]byte, error) {
@ -284,3 +318,29 @@ func (e Error) MarshalBencode() ([]byte, error) {
headerArgsField: e.Response, headerArgsField: e.Response,
}) })
} }
func (e *Error) UnmarshalBencode(b []byte) error {
var raw struct {
ID messageID `bencode:"1"`
NodeID bitmap `bencode:"2"`
ExceptionType string `bencode:"3"`
Args interface{} `bencode:"4"`
}
err := bencode.DecodeBytes(b, &raw)
if err != nil {
return err
}
e.ID = raw.ID
e.NodeID = raw.NodeID
e.ExceptionType = raw.ExceptionType
if reflect.TypeOf(raw.Args).Kind() == reflect.Slice {
v := reflect.ValueOf(raw.Args)
for i := 0; i < v.Len(); i++ {
e.Response = append(e.Response, cast.ToString(v.Index(i).Interface()))
}
}
return nil
}

View file

@ -48,7 +48,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) { if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
t.Error("blob hash mismatch") t.Error("blob hash mismatch")
} }
if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) { if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) {
t.Error("lbryid mismatch") t.Error("lbryid mismatch")
} }
if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port { if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port {
@ -76,7 +76,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
func TestBencodeFindNodesResponse(t *testing.T) { func TestBencodeFindNodesResponse(t *testing.T) {
res := Response{ res := Response{
ID: newMessageID(), ID: newMessageID(),
NodeID: newRandomBitmap().RawString(), NodeID: newRandomBitmap(),
FindNodeData: []Node{ FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, {id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765}, {id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
@ -100,7 +100,7 @@ func TestBencodeFindNodesResponse(t *testing.T) {
func TestBencodeFindValueResponse(t *testing.T) { func TestBencodeFindValueResponse(t *testing.T) {
res := Response{ res := Response{
ID: newMessageID(), ID: newMessageID(),
NodeID: newRandomBitmap().RawString(), NodeID: newRandomBitmap(),
FindValueKey: newRandomBitmap().RawString(), FindValueKey: newRandomBitmap().RawString(),
FindNodeData: []Node{ FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678}, {id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
@ -125,8 +125,8 @@ func compareResponses(t *testing.T, res, res2 Response) {
if res.ID != res2.ID { if res.ID != res2.ID {
t.Errorf("expected ID %s, got %s", res.ID, res2.ID) t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
} }
if res.NodeID != res2.NodeID { if !res.NodeID.Equals(res2.NodeID) {
t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID) t.Errorf("expected NodeID %s, got %s", res.NodeID.Hex(), res2.NodeID.Hex())
} }
if res.Data != res2.Data { if res.Data != res2.Data {
t.Errorf("expected Data %s, got %s", res.Data, res2.Data) t.Errorf("expected Data %s, got %s", res.Data, res2.Data)

View file

@ -1,28 +1,16 @@
package dht package dht
import ( import (
"crypto/rand"
"encoding/hex" "encoding/hex"
"net" "net"
"reflect"
"strings" "strings"
"time" "time"
"github.com/davecgh/go-spew/spew" "github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode" "github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
) )
func newMessageID() string {
buf := make([]byte, messageIDLength)
_, err := rand.Read(buf)
if err != nil {
panic(err)
}
return string(buf)
}
// handlePacket handles packets received from udp. // handlePacket handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) { func handlePacket(dht *DHT, pkt packet) {
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data)) //log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
@ -48,7 +36,7 @@ func handlePacket(dht *DHT, pkt packet) {
log.Errorln(err) log.Errorln(err)
return return
} }
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args)) log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args))
handleRequest(dht, pkt.raddr, request) handleRequest(dht, pkt.raddr, request)
case responseType: case responseType:
@ -58,17 +46,17 @@ func handlePacket(dht *DHT, pkt packet) {
log.Errorln(err) log.Errorln(err)
return return
} }
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug()) log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
handleResponse(dht, pkt.raddr, response) handleResponse(dht, pkt.raddr, response)
case errorType: case errorType:
e := Error{ e := Error{}
ID: data[headerMessageIDField].(string), err = bencode.DecodeBytes(pkt.data, &e)
NodeID: data[headerNodeIDField].(string), if err != nil {
ExceptionType: data[headerPayloadField].(string), log.Errorln(err)
Response: getArgs(data[headerArgsField]), return
} }
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType) log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
handleError(dht, pkt.raddr, e) handleError(dht, pkt.raddr, e)
default: default:
@ -79,14 +67,14 @@ func handlePacket(dht *DHT, pkt packet) {
// handleRequest handles the requests received from udp. // handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) { func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
if request.NodeID == dht.node.id.RawString() { if request.NodeID.Equals(dht.node.id) {
log.Warn("ignoring self-request") log.Warn("ignoring self-request")
return return
} }
switch request.Method { switch request.Method {
case pingMethod: case pingMethod:
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse}) send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
case storeMethod: case storeMethod:
if request.StoreArgs.BlobHash == "" { if request.StoreArgs.BlobHash == "" {
log.Errorln("blobhash is empty") log.Errorln("blobhash is empty")
@ -95,7 +83,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
// TODO: we should be sending the IP in the request, not just using the sender's IP // TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ??? // TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port}) dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse}) send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
case findNodeMethod: case findNodeMethod:
if len(request.Args) < 1 { if len(request.Args) < 1 {
log.Errorln("nothing to find") log.Errorln("nothing to find")
@ -117,7 +105,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
} }
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 { if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()} response := Response{ID: request.ID, NodeID: dht.node.id}
response.FindValueKey = request.Args[0] response.FindValueKey = request.Args[0]
response.FindNodeData = nodes response.FindNodeData = nodes
send(dht, addr, response) send(dht, addr, response)
@ -131,7 +119,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
return return
} }
node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port} node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node) dht.rt.Update(node)
} }
@ -139,7 +127,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
nodeID := newBitmapFromString(request.Args[0]) nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.rt.GetClosest(nodeID, bucketSize) closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
if len(closestNodes) > 0 { if len(closestNodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))} response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes { for i, n := range closestNodes {
response.FindNodeData[i] = n response.FindNodeData[i] = n
} }
@ -156,14 +144,14 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
tx.res <- &response tx.res <- &response
} }
node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port} node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node) dht.rt.Update(node)
} }
// handleError handles errors received from udp. // handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) { func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
spew.Dump(e) spew.Dump(e)
node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port} node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node) dht.rt.Update(node)
} }
@ -176,10 +164,10 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok { if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)", log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args)) dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok { } else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s", log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug()) dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
} else { } else {
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data)) log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
} }
@ -190,17 +178,6 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
return err return err
} }
func getArgs(argsInt interface{}) []string {
var args []string
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
args = append(args, cast.ToString(v.Index(i).Interface()))
}
}
return args
}
func argsToString(args []string) string { func argsToString(args []string) string {
argsCopy := make([]string, len(args)) argsCopy := make([]string, len(args))
copy(argsCopy, args) copy(argsCopy, args)

View file

@ -151,7 +151,7 @@ func TestPing(t *testing.T) {
rMessageID, ok := response[headerMessageIDField].(string) rMessageID, ok := response[headerMessageIDField].(string)
if !ok { if !ok {
t.Error("message ID is not a string") t.Error("message ID is not a string")
} else if rMessageID != messageID { } else if rMessageID != string(messageID[:]) {
t.Error("unexpected message ID") t.Error("unexpected message ID")
} }
} }
@ -203,16 +203,18 @@ func TestStore(t *testing.T) {
storeRequest := Request{ storeRequest := Request{
ID: messageID, ID: messageID,
NodeID: testNodeID.RawString(), NodeID: testNodeID,
Method: storeMethod, Method: storeMethod,
StoreArgs: &storeArgs{ StoreArgs: &storeArgs{
BlobHash: blobHashToStore, BlobHash: blobHashToStore,
Value: storeArgsValue{
Token: "arst",
LbryID: testNodeID,
Port: 9999,
},
NodeID: testNodeID, NodeID: testNodeID,
}, },
} }
storeRequest.StoreArgs.Value.Token = "arst"
storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString()
storeRequest.StoreArgs.Value.Port = 9999
_ = "64 " + // start message _ = "64 " + // start message
"313A30 693065" + // type: 0 "313A30 693065" + // type: 0
@ -305,7 +307,7 @@ func TestFindNode(t *testing.T) {
request := Request{ request := Request{
ID: messageID, ID: messageID,
NodeID: testNodeID.RawString(), NodeID: testNodeID,
Method: findNodeMethod, Method: findNodeMethod,
Args: []string{blobHashToFind}, Args: []string{blobHashToFind},
} }
@ -390,7 +392,7 @@ func TestFindValueExisting(t *testing.T) {
request := Request{ request := Request{
ID: messageID, ID: messageID,
NodeID: testNodeID.RawString(), NodeID: testNodeID,
Method: findValueMethod, Method: findValueMethod,
Args: []string{valueToFind}, Args: []string{valueToFind},
} }
@ -468,7 +470,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
request := Request{ request := Request{
ID: messageID, ID: messageID,
NodeID: testNodeID.RawString(), NodeID: testNodeID,
Method: findValueMethod, Method: findValueMethod,
Args: []string{valueToFind}, Args: []string{valueToFind},
} }
@ -517,7 +519,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
verifyContacts(t, contacts, nodes) verifyContacts(t, contacts, nodes)
} }
func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) { func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
if len(resp) != 4 { if len(resp) != 4 {
t.Errorf("expected 4 response fields, got %d", len(resp)) t.Errorf("expected 4 response fields, got %d", len(resp))
} }
@ -541,7 +543,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
rMessageID, ok := resp[headerMessageIDField].(string) rMessageID, ok := resp[headerMessageIDField].(string)
if !ok { if !ok {
t.Error("message ID is not a string") t.Error("message ID is not a string")
} else if rMessageID != messageID { } else if rMessageID != string(id[:]) {
t.Error("unexpected message ID") t.Error("unexpected message ID")
} }
if len(rMessageID) != messageIDLength { if len(rMessageID) != messageIDLength {

View file

@ -19,7 +19,7 @@ type transaction struct {
// transactionManager represents the manager of transactions. // transactionManager represents the manager of transactions.
type transactionManager struct { type transactionManager struct {
lock *sync.RWMutex lock *sync.RWMutex
transactions map[string]*transaction transactions map[messageID]*transaction
dht *DHT dht *DHT
} }
@ -27,7 +27,7 @@ type transactionManager struct {
func newTransactionManager(dht *DHT) *transactionManager { func newTransactionManager(dht *DHT) *transactionManager {
return &transactionManager{ return &transactionManager{
lock: &sync.RWMutex{}, lock: &sync.RWMutex{},
transactions: make(map[string]*transaction), transactions: make(map[messageID]*transaction),
dht: dht, dht: dht,
} }
} }
@ -40,14 +40,14 @@ func (tm *transactionManager) insert(trans *transaction) {
} }
// delete removes a transaction from transactionManager. // delete removes a transaction from transactionManager.
func (tm *transactionManager) delete(transID string) { func (tm *transactionManager) delete(id messageID) {
tm.lock.Lock() tm.lock.Lock()
defer tm.lock.Unlock() defer tm.lock.Unlock()
delete(tm.transactions, transID) delete(tm.transactions, id)
} }
// find transaction for id. optionally ensure that addr matches node from transaction // find transaction for id. optionally ensure that addr matches node from transaction
func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction { func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
tm.lock.RLock() tm.lock.RLock()
defer tm.lock.RUnlock() defer tm.lock.RUnlock()
@ -73,7 +73,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
defer close(ch) defer close(ch)
req.ID = newMessageID() req.ID = newMessageID()
req.NodeID = tm.dht.node.id.RawString() req.NodeID = tm.dht.node.id
trans := &transaction{ trans := &transaction{
node: node, node: node,
req: req, req: req,