proper types for some IDs
This commit is contained in:
parent
13f991852b
commit
868e243afc
8 changed files with 159 additions and 101 deletions
|
@ -5,21 +5,22 @@ import (
|
|||
"encoding/hex"
|
||||
"strconv"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
"github.com/lyoshenka/bencode"
|
||||
)
|
||||
|
||||
type bitmap [nodeIDLength]byte
|
||||
|
||||
func (b bitmap) RawString() string {
|
||||
return string(b[0:nodeIDLength])
|
||||
return string(b[:])
|
||||
}
|
||||
|
||||
func (b bitmap) Hex() string {
|
||||
return hex.EncodeToString(b[0:nodeIDLength])
|
||||
return hex.EncodeToString(b[:])
|
||||
}
|
||||
|
||||
func (b bitmap) HexShort() string {
|
||||
return hex.EncodeToString(b[0:nodeIDLength])[:8]
|
||||
return hex.EncodeToString(b[:4])
|
||||
}
|
||||
|
||||
func (b bitmap) Equals(other bitmap) bool {
|
||||
|
@ -66,6 +67,9 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(str) != nodeIDLength {
|
||||
return errors.Err("invalid node ID length")
|
||||
}
|
||||
copy(b[:], str)
|
||||
return nil
|
||||
}
|
||||
|
|
37
dht/dht.go
37
dht/dht.go
|
@ -203,7 +203,7 @@ func (dht *DHT) join() {
|
|||
}
|
||||
|
||||
// now call iterativeFind on yourself
|
||||
_, err := dht.FindNodes(dht.node.id)
|
||||
_, _, err := dht.Get(dht.node.id)
|
||||
if err != nil {
|
||||
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
|
||||
}
|
||||
|
@ -260,16 +260,7 @@ func printState(dht *DHT) {
|
|||
}
|
||||
}
|
||||
|
||||
func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) {
|
||||
nf := newNodeFinder(dht, hash, false)
|
||||
res, err := nf.Find()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res.Nodes, nil
|
||||
}
|
||||
|
||||
func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
|
||||
func (dht *DHT) Get(hash bitmap) ([]Node, bool, error) {
|
||||
nf := newNodeFinder(dht, hash, true)
|
||||
res, err := nf.Find()
|
||||
if err != nil {
|
||||
|
@ -278,6 +269,30 @@ func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
|
|||
return res.Nodes, res.Found, nil
|
||||
}
|
||||
|
||||
func (dht *DHT) Put(hash bitmap) error {
|
||||
nf := newNodeFinder(dht, hash, false)
|
||||
res, err := nf.Find()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, node := range res.Nodes {
|
||||
send(dht, node.Addr(), &Request{
|
||||
Method: storeMethod,
|
||||
StoreArgs: &storeArgs{
|
||||
BlobHash: hash.RawString(),
|
||||
Value: storeArgsValue{
|
||||
Token: "",
|
||||
LbryID: dht.node.id,
|
||||
Port: dht.node.port,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
type nodeFinder struct {
|
||||
findValue bool // true if we're using findValue
|
||||
target bitmap
|
||||
|
|
|
@ -4,8 +4,6 @@ import (
|
|||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
)
|
||||
|
||||
func TestDHT_FindNodes(t *testing.T) {
|
||||
|
@ -22,7 +20,7 @@ func TestDHT_FindNodes(t *testing.T) {
|
|||
go dht1.Start()
|
||||
defer dht1.Shutdown()
|
||||
|
||||
time.Sleep(1 * time.Second)
|
||||
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
||||
|
||||
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
|
||||
if err != nil {
|
||||
|
@ -42,13 +40,15 @@ func TestDHT_FindNodes(t *testing.T) {
|
|||
|
||||
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
||||
|
||||
foundNodes, err := dht3.FindNodes(id2)
|
||||
foundNodes, found, err := dht3.Get(newRandomBitmap())
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
spew.Dump(foundNodes)
|
||||
if found {
|
||||
t.Fatal("something was found, but it should not have been")
|
||||
}
|
||||
|
||||
if len(foundNodes) != 2 {
|
||||
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
|
||||
|
@ -74,7 +74,7 @@ func TestDHT_FindNodes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestDHT_FindValue(t *testing.T) {
|
||||
func TestDHT_Get(t *testing.T) {
|
||||
id1 := newRandomBitmap()
|
||||
id2 := newRandomBitmap()
|
||||
id3 := newRandomBitmap()
|
||||
|
@ -111,7 +111,7 @@ func TestDHT_FindValue(t *testing.T) {
|
|||
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
||||
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
|
||||
|
||||
foundNodes, found, err := dht3.FindValue(nodeToFind.id)
|
||||
foundNodes, found, err := dht3.Get(nodeToFind.id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/lbryio/errors.go"
|
||||
|
@ -41,9 +43,39 @@ type Message interface {
|
|||
bencode.Marshaler
|
||||
}
|
||||
|
||||
type messageID [messageIDLength]byte
|
||||
|
||||
func (m messageID) HexShort() string {
|
||||
return hex.EncodeToString(m[:])[:8]
|
||||
}
|
||||
|
||||
func (m *messageID) UnmarshalBencode(encoded []byte) error {
|
||||
var str string
|
||||
err := bencode.DecodeBytes(encoded, &str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
copy(m[:], str)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m messageID) MarshalBencode() ([]byte, error) {
|
||||
str := string(m[:])
|
||||
return bencode.EncodeBytes(str)
|
||||
}
|
||||
|
||||
func newMessageID() messageID {
|
||||
var m messageID
|
||||
_, err := rand.Read(m[:])
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
ID string
|
||||
NodeID string
|
||||
ID messageID
|
||||
NodeID bitmap
|
||||
Method string
|
||||
Args []string
|
||||
StoreArgs *storeArgs
|
||||
|
@ -67,8 +99,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
|
|||
|
||||
func (r *Request) UnmarshalBencode(b []byte) error {
|
||||
var raw struct {
|
||||
ID string `bencode:"1"`
|
||||
NodeID string `bencode:"2"`
|
||||
ID messageID `bencode:"1"`
|
||||
NodeID bitmap `bencode:"2"`
|
||||
Method string `bencode:"3"`
|
||||
Args bencode.RawMessage `bencode:"4"`
|
||||
}
|
||||
|
@ -94,13 +126,15 @@ func (r *Request) UnmarshalBencode(b []byte) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
type storeArgsValue struct {
|
||||
Token string `bencode:"token"`
|
||||
LbryID bitmap `bencode:"lbryid"`
|
||||
Port int `bencode:"port"`
|
||||
}
|
||||
|
||||
type storeArgs struct {
|
||||
BlobHash string
|
||||
Value struct {
|
||||
Token string `bencode:"token"`
|
||||
LbryID string `bencode:"lbryid"`
|
||||
Port int `bencode:"port"`
|
||||
}
|
||||
BlobHash string
|
||||
Value storeArgsValue
|
||||
NodeID bitmap
|
||||
SelfStore bool // this is an int on the wire
|
||||
}
|
||||
|
@ -167,8 +201,8 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
|
|||
}
|
||||
|
||||
type Response struct {
|
||||
ID string
|
||||
NodeID string
|
||||
ID messageID
|
||||
NodeID bitmap
|
||||
Data string
|
||||
FindNodeData []Node
|
||||
FindValueKey string
|
||||
|
@ -219,8 +253,8 @@ func (r Response) MarshalBencode() ([]byte, error) {
|
|||
|
||||
func (r *Response) UnmarshalBencode(b []byte) error {
|
||||
var raw struct {
|
||||
ID string `bencode:"1"`
|
||||
NodeID string `bencode:"2"`
|
||||
ID messageID `bencode:"1"`
|
||||
NodeID bitmap `bencode:"2"`
|
||||
Data bencode.RawMessage `bencode:"3"`
|
||||
}
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
|
@ -269,10 +303,10 @@ func (r *Response) UnmarshalBencode(b []byte) error {
|
|||
}
|
||||
|
||||
type Error struct {
|
||||
ID string
|
||||
NodeID string
|
||||
Response []string
|
||||
ID messageID
|
||||
NodeID bitmap
|
||||
ExceptionType string
|
||||
Response []string
|
||||
}
|
||||
|
||||
func (e Error) MarshalBencode() ([]byte, error) {
|
||||
|
@ -284,3 +318,29 @@ func (e Error) MarshalBencode() ([]byte, error) {
|
|||
headerArgsField: e.Response,
|
||||
})
|
||||
}
|
||||
|
||||
func (e *Error) UnmarshalBencode(b []byte) error {
|
||||
var raw struct {
|
||||
ID messageID `bencode:"1"`
|
||||
NodeID bitmap `bencode:"2"`
|
||||
ExceptionType string `bencode:"3"`
|
||||
Args interface{} `bencode:"4"`
|
||||
}
|
||||
err := bencode.DecodeBytes(b, &raw)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
e.ID = raw.ID
|
||||
e.NodeID = raw.NodeID
|
||||
e.ExceptionType = raw.ExceptionType
|
||||
|
||||
if reflect.TypeOf(raw.Args).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(raw.Args)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
e.Response = append(e.Response, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
|
|||
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
|
||||
t.Error("blob hash mismatch")
|
||||
}
|
||||
if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) {
|
||||
if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) {
|
||||
t.Error("lbryid mismatch")
|
||||
}
|
||||
if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port {
|
||||
|
@ -76,7 +76,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
|
|||
func TestBencodeFindNodesResponse(t *testing.T) {
|
||||
res := Response{
|
||||
ID: newMessageID(),
|
||||
NodeID: newRandomBitmap().RawString(),
|
||||
NodeID: newRandomBitmap(),
|
||||
FindNodeData: []Node{
|
||||
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
|
||||
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
|
||||
|
@ -100,7 +100,7 @@ func TestBencodeFindNodesResponse(t *testing.T) {
|
|||
func TestBencodeFindValueResponse(t *testing.T) {
|
||||
res := Response{
|
||||
ID: newMessageID(),
|
||||
NodeID: newRandomBitmap().RawString(),
|
||||
NodeID: newRandomBitmap(),
|
||||
FindValueKey: newRandomBitmap().RawString(),
|
||||
FindNodeData: []Node{
|
||||
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
|
||||
|
@ -125,8 +125,8 @@ func compareResponses(t *testing.T, res, res2 Response) {
|
|||
if res.ID != res2.ID {
|
||||
t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
|
||||
}
|
||||
if res.NodeID != res2.NodeID {
|
||||
t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID)
|
||||
if !res.NodeID.Equals(res2.NodeID) {
|
||||
t.Errorf("expected NodeID %s, got %s", res.NodeID.Hex(), res2.NodeID.Hex())
|
||||
}
|
||||
if res.Data != res2.Data {
|
||||
t.Errorf("expected Data %s, got %s", res.Data, res2.Data)
|
||||
|
|
59
dht/rpc.go
59
dht/rpc.go
|
@ -1,28 +1,16 @@
|
|||
package dht
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"net"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
"github.com/lyoshenka/bencode"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cast"
|
||||
)
|
||||
|
||||
func newMessageID() string {
|
||||
buf := make([]byte, messageIDLength)
|
||||
_, err := rand.Read(buf)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(buf)
|
||||
}
|
||||
|
||||
// handlePacket handles packets received from udp.
|
||||
func handlePacket(dht *DHT, pkt packet) {
|
||||
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||
|
@ -48,7 +36,7 @@ func handlePacket(dht *DHT, pkt packet) {
|
|||
log.Errorln(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
|
||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args))
|
||||
handleRequest(dht, pkt.raddr, request)
|
||||
|
||||
case responseType:
|
||||
|
@ -58,17 +46,17 @@ func handlePacket(dht *DHT, pkt packet) {
|
|||
log.Errorln(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug())
|
||||
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
|
||||
handleResponse(dht, pkt.raddr, response)
|
||||
|
||||
case errorType:
|
||||
e := Error{
|
||||
ID: data[headerMessageIDField].(string),
|
||||
NodeID: data[headerNodeIDField].(string),
|
||||
ExceptionType: data[headerPayloadField].(string),
|
||||
Response: getArgs(data[headerArgsField]),
|
||||
e := Error{}
|
||||
err = bencode.DecodeBytes(pkt.data, &e)
|
||||
if err != nil {
|
||||
log.Errorln(err)
|
||||
return
|
||||
}
|
||||
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
|
||||
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
|
||||
handleError(dht, pkt.raddr, e)
|
||||
|
||||
default:
|
||||
|
@ -79,14 +67,14 @@ func handlePacket(dht *DHT, pkt packet) {
|
|||
|
||||
// handleRequest handles the requests received from udp.
|
||||
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||
if request.NodeID == dht.node.id.RawString() {
|
||||
if request.NodeID.Equals(dht.node.id) {
|
||||
log.Warn("ignoring self-request")
|
||||
return
|
||||
}
|
||||
|
||||
switch request.Method {
|
||||
case pingMethod:
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
|
||||
case storeMethod:
|
||||
if request.StoreArgs.BlobHash == "" {
|
||||
log.Errorln("blobhash is empty")
|
||||
|
@ -95,7 +83,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
|||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
|
||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
||||
case findNodeMethod:
|
||||
if len(request.Args) < 1 {
|
||||
log.Errorln("nothing to find")
|
||||
|
@ -117,7 +105,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
|||
}
|
||||
|
||||
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id}
|
||||
response.FindValueKey = request.Args[0]
|
||||
response.FindNodeData = nodes
|
||||
send(dht, addr, response)
|
||||
|
@ -131,7 +119,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
|||
return
|
||||
}
|
||||
|
||||
node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
|
||||
node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
|
||||
dht.rt.Update(node)
|
||||
}
|
||||
|
||||
|
@ -139,7 +127,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
|||
nodeID := newBitmapFromString(request.Args[0])
|
||||
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
|
||||
if len(closestNodes) > 0 {
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
|
||||
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
|
||||
for i, n := range closestNodes {
|
||||
response.FindNodeData[i] = n
|
||||
}
|
||||
|
@ -156,14 +144,14 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
|
|||
tx.res <- &response
|
||||
}
|
||||
|
||||
node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
|
||||
node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
|
||||
dht.rt.Update(node)
|
||||
}
|
||||
|
||||
// handleError handles errors received from udp.
|
||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
|
||||
spew.Dump(e)
|
||||
node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
|
||||
node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port}
|
||||
dht.rt.Update(node)
|
||||
}
|
||||
|
||||
|
@ -176,10 +164,10 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
|||
|
||||
if req, ok := data.(Request); ok {
|
||||
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
|
||||
dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args))
|
||||
dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args))
|
||||
} else if res, ok := data.(Response); ok {
|
||||
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
|
||||
dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug())
|
||||
dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
|
||||
} else {
|
||||
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
|
||||
}
|
||||
|
@ -190,17 +178,6 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func getArgs(argsInt interface{}) []string {
|
||||
var args []string
|
||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
||||
v := reflect.ValueOf(argsInt)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
args = append(args, cast.ToString(v.Index(i).Interface()))
|
||||
}
|
||||
}
|
||||
return args
|
||||
}
|
||||
|
||||
func argsToString(args []string) string {
|
||||
argsCopy := make([]string, len(args))
|
||||
copy(argsCopy, args)
|
||||
|
|
|
@ -151,7 +151,7 @@ func TestPing(t *testing.T) {
|
|||
rMessageID, ok := response[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != messageID {
|
||||
} else if rMessageID != string(messageID[:]) {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
}
|
||||
|
@ -203,16 +203,18 @@ func TestStore(t *testing.T) {
|
|||
|
||||
storeRequest := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID.RawString(),
|
||||
NodeID: testNodeID,
|
||||
Method: storeMethod,
|
||||
StoreArgs: &storeArgs{
|
||||
BlobHash: blobHashToStore,
|
||||
NodeID: testNodeID,
|
||||
Value: storeArgsValue{
|
||||
Token: "arst",
|
||||
LbryID: testNodeID,
|
||||
Port: 9999,
|
||||
},
|
||||
NodeID: testNodeID,
|
||||
},
|
||||
}
|
||||
storeRequest.StoreArgs.Value.Token = "arst"
|
||||
storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString()
|
||||
storeRequest.StoreArgs.Value.Port = 9999
|
||||
|
||||
_ = "64 " + // start message
|
||||
"313A30 693065" + // type: 0
|
||||
|
@ -305,7 +307,7 @@ func TestFindNode(t *testing.T) {
|
|||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID.RawString(),
|
||||
NodeID: testNodeID,
|
||||
Method: findNodeMethod,
|
||||
Args: []string{blobHashToFind},
|
||||
}
|
||||
|
@ -390,7 +392,7 @@ func TestFindValueExisting(t *testing.T) {
|
|||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID.RawString(),
|
||||
NodeID: testNodeID,
|
||||
Method: findValueMethod,
|
||||
Args: []string{valueToFind},
|
||||
}
|
||||
|
@ -468,7 +470,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
|||
|
||||
request := Request{
|
||||
ID: messageID,
|
||||
NodeID: testNodeID.RawString(),
|
||||
NodeID: testNodeID,
|
||||
Method: findValueMethod,
|
||||
Args: []string{valueToFind},
|
||||
}
|
||||
|
@ -517,7 +519,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
|||
verifyContacts(t, contacts, nodes)
|
||||
}
|
||||
|
||||
func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) {
|
||||
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
|
||||
if len(resp) != 4 {
|
||||
t.Errorf("expected 4 response fields, got %d", len(resp))
|
||||
}
|
||||
|
@ -541,7 +543,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
|
|||
rMessageID, ok := resp[headerMessageIDField].(string)
|
||||
if !ok {
|
||||
t.Error("message ID is not a string")
|
||||
} else if rMessageID != messageID {
|
||||
} else if rMessageID != string(id[:]) {
|
||||
t.Error("unexpected message ID")
|
||||
}
|
||||
if len(rMessageID) != messageIDLength {
|
||||
|
|
|
@ -19,7 +19,7 @@ type transaction struct {
|
|||
// transactionManager represents the manager of transactions.
|
||||
type transactionManager struct {
|
||||
lock *sync.RWMutex
|
||||
transactions map[string]*transaction
|
||||
transactions map[messageID]*transaction
|
||||
dht *DHT
|
||||
}
|
||||
|
||||
|
@ -27,7 +27,7 @@ type transactionManager struct {
|
|||
func newTransactionManager(dht *DHT) *transactionManager {
|
||||
return &transactionManager{
|
||||
lock: &sync.RWMutex{},
|
||||
transactions: make(map[string]*transaction),
|
||||
transactions: make(map[messageID]*transaction),
|
||||
dht: dht,
|
||||
}
|
||||
}
|
||||
|
@ -40,14 +40,14 @@ func (tm *transactionManager) insert(trans *transaction) {
|
|||
}
|
||||
|
||||
// delete removes a transaction from transactionManager.
|
||||
func (tm *transactionManager) delete(transID string) {
|
||||
func (tm *transactionManager) delete(id messageID) {
|
||||
tm.lock.Lock()
|
||||
defer tm.lock.Unlock()
|
||||
delete(tm.transactions, transID)
|
||||
delete(tm.transactions, id)
|
||||
}
|
||||
|
||||
// find transaction for id. optionally ensure that addr matches node from transaction
|
||||
func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction {
|
||||
func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
|
||||
tm.lock.RLock()
|
||||
defer tm.lock.RUnlock()
|
||||
|
||||
|
@ -73,7 +73,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
|
|||
defer close(ch)
|
||||
|
||||
req.ID = newMessageID()
|
||||
req.NodeID = tm.dht.node.id.RawString()
|
||||
req.NodeID = tm.dht.node.id
|
||||
trans := &transaction{
|
||||
node: node,
|
||||
req: req,
|
||||
|
|
Loading…
Reference in a new issue