proper types for some IDs
This commit is contained in:
parent
a5ef461fc5
commit
ea8d0d1eed
8 changed files with 159 additions and 101 deletions
|
@ -5,21 +5,22 @@ import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"github.com/lbryio/errors.go"
|
||||||
"github.com/lyoshenka/bencode"
|
"github.com/lyoshenka/bencode"
|
||||||
)
|
)
|
||||||
|
|
||||||
type bitmap [nodeIDLength]byte
|
type bitmap [nodeIDLength]byte
|
||||||
|
|
||||||
func (b bitmap) RawString() string {
|
func (b bitmap) RawString() string {
|
||||||
return string(b[0:nodeIDLength])
|
return string(b[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) Hex() string {
|
func (b bitmap) Hex() string {
|
||||||
return hex.EncodeToString(b[0:nodeIDLength])
|
return hex.EncodeToString(b[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) HexShort() string {
|
func (b bitmap) HexShort() string {
|
||||||
return hex.EncodeToString(b[0:nodeIDLength])[:8]
|
return hex.EncodeToString(b[:4])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b bitmap) Equals(other bitmap) bool {
|
func (b bitmap) Equals(other bitmap) bool {
|
||||||
|
@ -66,6 +67,9 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
if len(str) != nodeIDLength {
|
||||||
|
return errors.Err("invalid node ID length")
|
||||||
|
}
|
||||||
copy(b[:], str)
|
copy(b[:], str)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
37
dht/dht.go
37
dht/dht.go
|
@ -203,7 +203,7 @@ func (dht *DHT) join() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// now call iterativeFind on yourself
|
// now call iterativeFind on yourself
|
||||||
_, err := dht.FindNodes(dht.node.id)
|
_, _, err := dht.Get(dht.node.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
|
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
|
||||||
}
|
}
|
||||||
|
@ -260,16 +260,7 @@ func printState(dht *DHT) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) {
|
func (dht *DHT) Get(hash bitmap) ([]Node, bool, error) {
|
||||||
nf := newNodeFinder(dht, hash, false)
|
|
||||||
res, err := nf.Find()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return res.Nodes, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
|
|
||||||
nf := newNodeFinder(dht, hash, true)
|
nf := newNodeFinder(dht, hash, true)
|
||||||
res, err := nf.Find()
|
res, err := nf.Find()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -278,6 +269,30 @@ func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
|
||||||
return res.Nodes, res.Found, nil
|
return res.Nodes, res.Found, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (dht *DHT) Put(hash bitmap) error {
|
||||||
|
nf := newNodeFinder(dht, hash, false)
|
||||||
|
res, err := nf.Find()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, node := range res.Nodes {
|
||||||
|
send(dht, node.Addr(), &Request{
|
||||||
|
Method: storeMethod,
|
||||||
|
StoreArgs: &storeArgs{
|
||||||
|
BlobHash: hash.RawString(),
|
||||||
|
Value: storeArgsValue{
|
||||||
|
Token: "",
|
||||||
|
LbryID: dht.node.id,
|
||||||
|
Port: dht.node.port,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
type nodeFinder struct {
|
type nodeFinder struct {
|
||||||
findValue bool // true if we're using findValue
|
findValue bool // true if we're using findValue
|
||||||
target bitmap
|
target bitmap
|
||||||
|
|
|
@ -4,8 +4,6 @@ import (
|
||||||
"net"
|
"net"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDHT_FindNodes(t *testing.T) {
|
func TestDHT_FindNodes(t *testing.T) {
|
||||||
|
@ -22,7 +20,7 @@ func TestDHT_FindNodes(t *testing.T) {
|
||||||
go dht1.Start()
|
go dht1.Start()
|
||||||
defer dht1.Shutdown()
|
defer dht1.Shutdown()
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
||||||
|
|
||||||
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
|
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -42,13 +40,15 @@ func TestDHT_FindNodes(t *testing.T) {
|
||||||
|
|
||||||
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
time.Sleep(1 * time.Second) // give dhts a chance to connect
|
||||||
|
|
||||||
foundNodes, err := dht3.FindNodes(id2)
|
foundNodes, found, err := dht3.Get(newRandomBitmap())
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
spew.Dump(foundNodes)
|
if found {
|
||||||
|
t.Fatal("something was found, but it should not have been")
|
||||||
|
}
|
||||||
|
|
||||||
if len(foundNodes) != 2 {
|
if len(foundNodes) != 2 {
|
||||||
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
|
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
|
||||||
|
@ -74,7 +74,7 @@ func TestDHT_FindNodes(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDHT_FindValue(t *testing.T) {
|
func TestDHT_Get(t *testing.T) {
|
||||||
id1 := newRandomBitmap()
|
id1 := newRandomBitmap()
|
||||||
id2 := newRandomBitmap()
|
id2 := newRandomBitmap()
|
||||||
id3 := newRandomBitmap()
|
id3 := newRandomBitmap()
|
||||||
|
@ -111,7 +111,7 @@ func TestDHT_FindValue(t *testing.T) {
|
||||||
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
|
||||||
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
|
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
|
||||||
|
|
||||||
foundNodes, found, err := dht3.FindValue(nodeToFind.id)
|
foundNodes, found, err := dht3.Get(nodeToFind.id)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package dht
|
package dht
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/lbryio/errors.go"
|
"github.com/lbryio/errors.go"
|
||||||
|
@ -41,9 +43,39 @@ type Message interface {
|
||||||
bencode.Marshaler
|
bencode.Marshaler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type messageID [messageIDLength]byte
|
||||||
|
|
||||||
|
func (m messageID) HexShort() string {
|
||||||
|
return hex.EncodeToString(m[:])[:8]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *messageID) UnmarshalBencode(encoded []byte) error {
|
||||||
|
var str string
|
||||||
|
err := bencode.DecodeBytes(encoded, &str)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
copy(m[:], str)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m messageID) MarshalBencode() ([]byte, error) {
|
||||||
|
str := string(m[:])
|
||||||
|
return bencode.EncodeBytes(str)
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessageID() messageID {
|
||||||
|
var m messageID
|
||||||
|
_, err := rand.Read(m[:])
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
type Request struct {
|
type Request struct {
|
||||||
ID string
|
ID messageID
|
||||||
NodeID string
|
NodeID bitmap
|
||||||
Method string
|
Method string
|
||||||
Args []string
|
Args []string
|
||||||
StoreArgs *storeArgs
|
StoreArgs *storeArgs
|
||||||
|
@ -67,8 +99,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
|
||||||
|
|
||||||
func (r *Request) UnmarshalBencode(b []byte) error {
|
func (r *Request) UnmarshalBencode(b []byte) error {
|
||||||
var raw struct {
|
var raw struct {
|
||||||
ID string `bencode:"1"`
|
ID messageID `bencode:"1"`
|
||||||
NodeID string `bencode:"2"`
|
NodeID bitmap `bencode:"2"`
|
||||||
Method string `bencode:"3"`
|
Method string `bencode:"3"`
|
||||||
Args bencode.RawMessage `bencode:"4"`
|
Args bencode.RawMessage `bencode:"4"`
|
||||||
}
|
}
|
||||||
|
@ -94,13 +126,15 @@ func (r *Request) UnmarshalBencode(b []byte) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type storeArgs struct {
|
type storeArgsValue struct {
|
||||||
BlobHash string
|
|
||||||
Value struct {
|
|
||||||
Token string `bencode:"token"`
|
Token string `bencode:"token"`
|
||||||
LbryID string `bencode:"lbryid"`
|
LbryID bitmap `bencode:"lbryid"`
|
||||||
Port int `bencode:"port"`
|
Port int `bencode:"port"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type storeArgs struct {
|
||||||
|
BlobHash string
|
||||||
|
Value storeArgsValue
|
||||||
NodeID bitmap
|
NodeID bitmap
|
||||||
SelfStore bool // this is an int on the wire
|
SelfStore bool // this is an int on the wire
|
||||||
}
|
}
|
||||||
|
@ -167,8 +201,8 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Response struct {
|
type Response struct {
|
||||||
ID string
|
ID messageID
|
||||||
NodeID string
|
NodeID bitmap
|
||||||
Data string
|
Data string
|
||||||
FindNodeData []Node
|
FindNodeData []Node
|
||||||
FindValueKey string
|
FindValueKey string
|
||||||
|
@ -219,8 +253,8 @@ func (r Response) MarshalBencode() ([]byte, error) {
|
||||||
|
|
||||||
func (r *Response) UnmarshalBencode(b []byte) error {
|
func (r *Response) UnmarshalBencode(b []byte) error {
|
||||||
var raw struct {
|
var raw struct {
|
||||||
ID string `bencode:"1"`
|
ID messageID `bencode:"1"`
|
||||||
NodeID string `bencode:"2"`
|
NodeID bitmap `bencode:"2"`
|
||||||
Data bencode.RawMessage `bencode:"3"`
|
Data bencode.RawMessage `bencode:"3"`
|
||||||
}
|
}
|
||||||
err := bencode.DecodeBytes(b, &raw)
|
err := bencode.DecodeBytes(b, &raw)
|
||||||
|
@ -269,10 +303,10 @@ func (r *Response) UnmarshalBencode(b []byte) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Error struct {
|
type Error struct {
|
||||||
ID string
|
ID messageID
|
||||||
NodeID string
|
NodeID bitmap
|
||||||
Response []string
|
|
||||||
ExceptionType string
|
ExceptionType string
|
||||||
|
Response []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e Error) MarshalBencode() ([]byte, error) {
|
func (e Error) MarshalBencode() ([]byte, error) {
|
||||||
|
@ -284,3 +318,29 @@ func (e Error) MarshalBencode() ([]byte, error) {
|
||||||
headerArgsField: e.Response,
|
headerArgsField: e.Response,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Error) UnmarshalBencode(b []byte) error {
|
||||||
|
var raw struct {
|
||||||
|
ID messageID `bencode:"1"`
|
||||||
|
NodeID bitmap `bencode:"2"`
|
||||||
|
ExceptionType string `bencode:"3"`
|
||||||
|
Args interface{} `bencode:"4"`
|
||||||
|
}
|
||||||
|
err := bencode.DecodeBytes(b, &raw)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
e.ID = raw.ID
|
||||||
|
e.NodeID = raw.NodeID
|
||||||
|
e.ExceptionType = raw.ExceptionType
|
||||||
|
|
||||||
|
if reflect.TypeOf(raw.Args).Kind() == reflect.Slice {
|
||||||
|
v := reflect.ValueOf(raw.Args)
|
||||||
|
for i := 0; i < v.Len(); i++ {
|
||||||
|
e.Response = append(e.Response, cast.ToString(v.Index(i).Interface()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -48,7 +48,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
|
||||||
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
|
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
|
||||||
t.Error("blob hash mismatch")
|
t.Error("blob hash mismatch")
|
||||||
}
|
}
|
||||||
if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) {
|
if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) {
|
||||||
t.Error("lbryid mismatch")
|
t.Error("lbryid mismatch")
|
||||||
}
|
}
|
||||||
if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port {
|
if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port {
|
||||||
|
@ -76,7 +76,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
|
||||||
func TestBencodeFindNodesResponse(t *testing.T) {
|
func TestBencodeFindNodesResponse(t *testing.T) {
|
||||||
res := Response{
|
res := Response{
|
||||||
ID: newMessageID(),
|
ID: newMessageID(),
|
||||||
NodeID: newRandomBitmap().RawString(),
|
NodeID: newRandomBitmap(),
|
||||||
FindNodeData: []Node{
|
FindNodeData: []Node{
|
||||||
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
|
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
|
||||||
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
|
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
|
||||||
|
@ -100,7 +100,7 @@ func TestBencodeFindNodesResponse(t *testing.T) {
|
||||||
func TestBencodeFindValueResponse(t *testing.T) {
|
func TestBencodeFindValueResponse(t *testing.T) {
|
||||||
res := Response{
|
res := Response{
|
||||||
ID: newMessageID(),
|
ID: newMessageID(),
|
||||||
NodeID: newRandomBitmap().RawString(),
|
NodeID: newRandomBitmap(),
|
||||||
FindValueKey: newRandomBitmap().RawString(),
|
FindValueKey: newRandomBitmap().RawString(),
|
||||||
FindNodeData: []Node{
|
FindNodeData: []Node{
|
||||||
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
|
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
|
||||||
|
@ -125,8 +125,8 @@ func compareResponses(t *testing.T, res, res2 Response) {
|
||||||
if res.ID != res2.ID {
|
if res.ID != res2.ID {
|
||||||
t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
|
t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
|
||||||
}
|
}
|
||||||
if res.NodeID != res2.NodeID {
|
if !res.NodeID.Equals(res2.NodeID) {
|
||||||
t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID)
|
t.Errorf("expected NodeID %s, got %s", res.NodeID.Hex(), res2.NodeID.Hex())
|
||||||
}
|
}
|
||||||
if res.Data != res2.Data {
|
if res.Data != res2.Data {
|
||||||
t.Errorf("expected Data %s, got %s", res.Data, res2.Data)
|
t.Errorf("expected Data %s, got %s", res.Data, res2.Data)
|
||||||
|
|
59
dht/rpc.go
59
dht/rpc.go
|
@ -1,28 +1,16 @@
|
||||||
package dht
|
package dht
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/rand"
|
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"net"
|
"net"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
"github.com/davecgh/go-spew/spew"
|
||||||
"github.com/lyoshenka/bencode"
|
"github.com/lyoshenka/bencode"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cast"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func newMessageID() string {
|
|
||||||
buf := make([]byte, messageIDLength)
|
|
||||||
_, err := rand.Read(buf)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
return string(buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
// handlePacket handles packets received from udp.
|
// handlePacket handles packets received from udp.
|
||||||
func handlePacket(dht *DHT, pkt packet) {
|
func handlePacket(dht *DHT, pkt packet) {
|
||||||
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
|
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
|
||||||
|
@ -48,7 +36,7 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
log.Errorln(err)
|
log.Errorln(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
|
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args))
|
||||||
handleRequest(dht, pkt.raddr, request)
|
handleRequest(dht, pkt.raddr, request)
|
||||||
|
|
||||||
case responseType:
|
case responseType:
|
||||||
|
@ -58,17 +46,17 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
log.Errorln(err)
|
log.Errorln(err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug())
|
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
|
||||||
handleResponse(dht, pkt.raddr, response)
|
handleResponse(dht, pkt.raddr, response)
|
||||||
|
|
||||||
case errorType:
|
case errorType:
|
||||||
e := Error{
|
e := Error{}
|
||||||
ID: data[headerMessageIDField].(string),
|
err = bencode.DecodeBytes(pkt.data, &e)
|
||||||
NodeID: data[headerNodeIDField].(string),
|
if err != nil {
|
||||||
ExceptionType: data[headerPayloadField].(string),
|
log.Errorln(err)
|
||||||
Response: getArgs(data[headerArgsField]),
|
return
|
||||||
}
|
}
|
||||||
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
|
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
|
||||||
handleError(dht, pkt.raddr, e)
|
handleError(dht, pkt.raddr, e)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -79,14 +67,14 @@ func handlePacket(dht *DHT, pkt packet) {
|
||||||
|
|
||||||
// handleRequest handles the requests received from udp.
|
// handleRequest handles the requests received from udp.
|
||||||
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
if request.NodeID == dht.node.id.RawString() {
|
if request.NodeID.Equals(dht.node.id) {
|
||||||
log.Warn("ignoring self-request")
|
log.Warn("ignoring self-request")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch request.Method {
|
switch request.Method {
|
||||||
case pingMethod:
|
case pingMethod:
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
|
||||||
case storeMethod:
|
case storeMethod:
|
||||||
if request.StoreArgs.BlobHash == "" {
|
if request.StoreArgs.BlobHash == "" {
|
||||||
log.Errorln("blobhash is empty")
|
log.Errorln("blobhash is empty")
|
||||||
|
@ -95,7 +83,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
// TODO: we should be sending the IP in the request, not just using the sender's IP
|
||||||
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
|
||||||
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
|
||||||
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
|
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
|
||||||
case findNodeMethod:
|
case findNodeMethod:
|
||||||
if len(request.Args) < 1 {
|
if len(request.Args) < 1 {
|
||||||
log.Errorln("nothing to find")
|
log.Errorln("nothing to find")
|
||||||
|
@ -117,7 +105,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
|
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
|
||||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
|
response := Response{ID: request.ID, NodeID: dht.node.id}
|
||||||
response.FindValueKey = request.Args[0]
|
response.FindValueKey = request.Args[0]
|
||||||
response.FindNodeData = nodes
|
response.FindNodeData = nodes
|
||||||
send(dht, addr, response)
|
send(dht, addr, response)
|
||||||
|
@ -131,7 +119,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
|
node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
|
||||||
dht.rt.Update(node)
|
dht.rt.Update(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -139,7 +127,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
|
||||||
nodeID := newBitmapFromString(request.Args[0])
|
nodeID := newBitmapFromString(request.Args[0])
|
||||||
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
|
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
|
||||||
if len(closestNodes) > 0 {
|
if len(closestNodes) > 0 {
|
||||||
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
|
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
|
||||||
for i, n := range closestNodes {
|
for i, n := range closestNodes {
|
||||||
response.FindNodeData[i] = n
|
response.FindNodeData[i] = n
|
||||||
}
|
}
|
||||||
|
@ -156,14 +144,14 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
|
||||||
tx.res <- &response
|
tx.res <- &response
|
||||||
}
|
}
|
||||||
|
|
||||||
node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
|
node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
|
||||||
dht.rt.Update(node)
|
dht.rt.Update(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleError handles errors received from udp.
|
// handleError handles errors received from udp.
|
||||||
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
|
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
|
||||||
spew.Dump(e)
|
spew.Dump(e)
|
||||||
node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
|
node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port}
|
||||||
dht.rt.Update(node)
|
dht.rt.Update(node)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,10 +164,10 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||||
|
|
||||||
if req, ok := data.(Request); ok {
|
if req, ok := data.(Request); ok {
|
||||||
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
|
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
|
||||||
dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args))
|
dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args))
|
||||||
} else if res, ok := data.(Response); ok {
|
} else if res, ok := data.(Response); ok {
|
||||||
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
|
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
|
||||||
dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug())
|
dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
|
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
|
||||||
}
|
}
|
||||||
|
@ -190,17 +178,6 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func getArgs(argsInt interface{}) []string {
|
|
||||||
var args []string
|
|
||||||
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
|
|
||||||
v := reflect.ValueOf(argsInt)
|
|
||||||
for i := 0; i < v.Len(); i++ {
|
|
||||||
args = append(args, cast.ToString(v.Index(i).Interface()))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return args
|
|
||||||
}
|
|
||||||
|
|
||||||
func argsToString(args []string) string {
|
func argsToString(args []string) string {
|
||||||
argsCopy := make([]string, len(args))
|
argsCopy := make([]string, len(args))
|
||||||
copy(argsCopy, args)
|
copy(argsCopy, args)
|
||||||
|
|
|
@ -151,7 +151,7 @@ func TestPing(t *testing.T) {
|
||||||
rMessageID, ok := response[headerMessageIDField].(string)
|
rMessageID, ok := response[headerMessageIDField].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("message ID is not a string")
|
t.Error("message ID is not a string")
|
||||||
} else if rMessageID != messageID {
|
} else if rMessageID != string(messageID[:]) {
|
||||||
t.Error("unexpected message ID")
|
t.Error("unexpected message ID")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -203,16 +203,18 @@ func TestStore(t *testing.T) {
|
||||||
|
|
||||||
storeRequest := Request{
|
storeRequest := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID.RawString(),
|
NodeID: testNodeID,
|
||||||
Method: storeMethod,
|
Method: storeMethod,
|
||||||
StoreArgs: &storeArgs{
|
StoreArgs: &storeArgs{
|
||||||
BlobHash: blobHashToStore,
|
BlobHash: blobHashToStore,
|
||||||
|
Value: storeArgsValue{
|
||||||
|
Token: "arst",
|
||||||
|
LbryID: testNodeID,
|
||||||
|
Port: 9999,
|
||||||
|
},
|
||||||
NodeID: testNodeID,
|
NodeID: testNodeID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
storeRequest.StoreArgs.Value.Token = "arst"
|
|
||||||
storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString()
|
|
||||||
storeRequest.StoreArgs.Value.Port = 9999
|
|
||||||
|
|
||||||
_ = "64 " + // start message
|
_ = "64 " + // start message
|
||||||
"313A30 693065" + // type: 0
|
"313A30 693065" + // type: 0
|
||||||
|
@ -305,7 +307,7 @@ func TestFindNode(t *testing.T) {
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID.RawString(),
|
NodeID: testNodeID,
|
||||||
Method: findNodeMethod,
|
Method: findNodeMethod,
|
||||||
Args: []string{blobHashToFind},
|
Args: []string{blobHashToFind},
|
||||||
}
|
}
|
||||||
|
@ -390,7 +392,7 @@ func TestFindValueExisting(t *testing.T) {
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID.RawString(),
|
NodeID: testNodeID,
|
||||||
Method: findValueMethod,
|
Method: findValueMethod,
|
||||||
Args: []string{valueToFind},
|
Args: []string{valueToFind},
|
||||||
}
|
}
|
||||||
|
@ -468,7 +470,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||||
|
|
||||||
request := Request{
|
request := Request{
|
||||||
ID: messageID,
|
ID: messageID,
|
||||||
NodeID: testNodeID.RawString(),
|
NodeID: testNodeID,
|
||||||
Method: findValueMethod,
|
Method: findValueMethod,
|
||||||
Args: []string{valueToFind},
|
Args: []string{valueToFind},
|
||||||
}
|
}
|
||||||
|
@ -517,7 +519,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
|
||||||
verifyContacts(t, contacts, nodes)
|
verifyContacts(t, contacts, nodes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) {
|
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
|
||||||
if len(resp) != 4 {
|
if len(resp) != 4 {
|
||||||
t.Errorf("expected 4 response fields, got %d", len(resp))
|
t.Errorf("expected 4 response fields, got %d", len(resp))
|
||||||
}
|
}
|
||||||
|
@ -541,7 +543,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
|
||||||
rMessageID, ok := resp[headerMessageIDField].(string)
|
rMessageID, ok := resp[headerMessageIDField].(string)
|
||||||
if !ok {
|
if !ok {
|
||||||
t.Error("message ID is not a string")
|
t.Error("message ID is not a string")
|
||||||
} else if rMessageID != messageID {
|
} else if rMessageID != string(id[:]) {
|
||||||
t.Error("unexpected message ID")
|
t.Error("unexpected message ID")
|
||||||
}
|
}
|
||||||
if len(rMessageID) != messageIDLength {
|
if len(rMessageID) != messageIDLength {
|
||||||
|
|
|
@ -19,7 +19,7 @@ type transaction struct {
|
||||||
// transactionManager represents the manager of transactions.
|
// transactionManager represents the manager of transactions.
|
||||||
type transactionManager struct {
|
type transactionManager struct {
|
||||||
lock *sync.RWMutex
|
lock *sync.RWMutex
|
||||||
transactions map[string]*transaction
|
transactions map[messageID]*transaction
|
||||||
dht *DHT
|
dht *DHT
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,7 +27,7 @@ type transactionManager struct {
|
||||||
func newTransactionManager(dht *DHT) *transactionManager {
|
func newTransactionManager(dht *DHT) *transactionManager {
|
||||||
return &transactionManager{
|
return &transactionManager{
|
||||||
lock: &sync.RWMutex{},
|
lock: &sync.RWMutex{},
|
||||||
transactions: make(map[string]*transaction),
|
transactions: make(map[messageID]*transaction),
|
||||||
dht: dht,
|
dht: dht,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -40,14 +40,14 @@ func (tm *transactionManager) insert(trans *transaction) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// delete removes a transaction from transactionManager.
|
// delete removes a transaction from transactionManager.
|
||||||
func (tm *transactionManager) delete(transID string) {
|
func (tm *transactionManager) delete(id messageID) {
|
||||||
tm.lock.Lock()
|
tm.lock.Lock()
|
||||||
defer tm.lock.Unlock()
|
defer tm.lock.Unlock()
|
||||||
delete(tm.transactions, transID)
|
delete(tm.transactions, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
// find transaction for id. optionally ensure that addr matches node from transaction
|
// find transaction for id. optionally ensure that addr matches node from transaction
|
||||||
func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction {
|
func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
|
||||||
tm.lock.RLock()
|
tm.lock.RLock()
|
||||||
defer tm.lock.RUnlock()
|
defer tm.lock.RUnlock()
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
|
||||||
defer close(ch)
|
defer close(ch)
|
||||||
|
|
||||||
req.ID = newMessageID()
|
req.ID = newMessageID()
|
||||||
req.NodeID = tm.dht.node.id.RawString()
|
req.NodeID = tm.dht.node.id
|
||||||
trans := &transaction{
|
trans := &transaction{
|
||||||
node: node,
|
node: node,
|
||||||
req: req,
|
req: req,
|
||||||
|
|
Loading…
Reference in a new issue