proper types for some IDs

This commit is contained in:
Alex Grintsvayg 2018-04-03 13:38:01 -04:00
parent 13f991852b
commit 868e243afc
8 changed files with 159 additions and 101 deletions

View file

@ -5,21 +5,22 @@ import (
"encoding/hex"
"strconv"
"github.com/lbryio/errors.go"
"github.com/lyoshenka/bencode"
)
type bitmap [nodeIDLength]byte
func (b bitmap) RawString() string {
return string(b[0:nodeIDLength])
return string(b[:])
}
func (b bitmap) Hex() string {
return hex.EncodeToString(b[0:nodeIDLength])
return hex.EncodeToString(b[:])
}
func (b bitmap) HexShort() string {
return hex.EncodeToString(b[0:nodeIDLength])[:8]
return hex.EncodeToString(b[:4])
}
func (b bitmap) Equals(other bitmap) bool {
@ -66,6 +67,9 @@ func (b *bitmap) UnmarshalBencode(encoded []byte) error {
if err != nil {
return err
}
if len(str) != nodeIDLength {
return errors.Err("invalid node ID length")
}
copy(b[:], str)
return nil
}

View file

@ -203,7 +203,7 @@ func (dht *DHT) join() {
}
// now call iterativeFind on yourself
_, err := dht.FindNodes(dht.node.id)
_, _, err := dht.Get(dht.node.id)
if err != nil {
log.Errorf("[%s] join: %s", dht.node.id.HexShort(), err.Error())
}
@ -260,16 +260,7 @@ func printState(dht *DHT) {
}
}
func (dht *DHT) FindNodes(hash bitmap) ([]Node, error) {
nf := newNodeFinder(dht, hash, false)
res, err := nf.Find()
if err != nil {
return nil, err
}
return res.Nodes, nil
}
func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
func (dht *DHT) Get(hash bitmap) ([]Node, bool, error) {
nf := newNodeFinder(dht, hash, true)
res, err := nf.Find()
if err != nil {
@ -278,6 +269,30 @@ func (dht *DHT) FindValue(hash bitmap) ([]Node, bool, error) {
return res.Nodes, res.Found, nil
}
func (dht *DHT) Put(hash bitmap) error {
nf := newNodeFinder(dht, hash, false)
res, err := nf.Find()
if err != nil {
return err
}
for _, node := range res.Nodes {
send(dht, node.Addr(), &Request{
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: hash.RawString(),
Value: storeArgsValue{
Token: "",
LbryID: dht.node.id,
Port: dht.node.port,
},
},
})
}
return nil
}
type nodeFinder struct {
findValue bool // true if we're using findValue
target bitmap

View file

@ -4,8 +4,6 @@ import (
"net"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
)
func TestDHT_FindNodes(t *testing.T) {
@ -22,7 +20,7 @@ func TestDHT_FindNodes(t *testing.T) {
go dht1.Start()
defer dht1.Shutdown()
time.Sleep(1 * time.Second)
time.Sleep(1 * time.Second) // give dhts a chance to connect
dht2, err := New(&Config{Address: "127.0.0.1:21217", NodeID: id2.Hex(), SeedNodes: []string{seedIP}})
if err != nil {
@ -42,13 +40,15 @@ func TestDHT_FindNodes(t *testing.T) {
time.Sleep(1 * time.Second) // give dhts a chance to connect
foundNodes, err := dht3.FindNodes(id2)
foundNodes, found, err := dht3.Get(newRandomBitmap())
if err != nil {
t.Fatal(err)
}
spew.Dump(foundNodes)
if found {
t.Fatal("something was found, but it should not have been")
}
if len(foundNodes) != 2 {
t.Errorf("expected 2 nodes, found %d", len(foundNodes))
@ -74,7 +74,7 @@ func TestDHT_FindNodes(t *testing.T) {
}
}
func TestDHT_FindValue(t *testing.T) {
func TestDHT_Get(t *testing.T) {
id1 := newRandomBitmap()
id2 := newRandomBitmap()
id3 := newRandomBitmap()
@ -111,7 +111,7 @@ func TestDHT_FindValue(t *testing.T) {
nodeToFind := Node{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4), port: 5678}
dht1.store.Upsert(nodeToFind.id.RawString(), nodeToFind)
foundNodes, found, err := dht3.FindValue(nodeToFind.id)
foundNodes, found, err := dht3.Get(nodeToFind.id)
if err != nil {
t.Fatal(err)
}

View file

@ -1,7 +1,9 @@
package dht
import (
"crypto/rand"
"encoding/hex"
"reflect"
"strings"
"github.com/lbryio/errors.go"
@ -41,9 +43,39 @@ type Message interface {
bencode.Marshaler
}
type messageID [messageIDLength]byte
func (m messageID) HexShort() string {
return hex.EncodeToString(m[:])[:8]
}
func (m *messageID) UnmarshalBencode(encoded []byte) error {
var str string
err := bencode.DecodeBytes(encoded, &str)
if err != nil {
return err
}
copy(m[:], str)
return nil
}
func (m messageID) MarshalBencode() ([]byte, error) {
str := string(m[:])
return bencode.EncodeBytes(str)
}
func newMessageID() messageID {
var m messageID
_, err := rand.Read(m[:])
if err != nil {
panic(err)
}
return m
}
type Request struct {
ID string
NodeID string
ID messageID
NodeID bitmap
Method string
Args []string
StoreArgs *storeArgs
@ -67,8 +99,8 @@ func (r Request) MarshalBencode() ([]byte, error) {
func (r *Request) UnmarshalBencode(b []byte) error {
var raw struct {
ID string `bencode:"1"`
NodeID string `bencode:"2"`
ID messageID `bencode:"1"`
NodeID bitmap `bencode:"2"`
Method string `bencode:"3"`
Args bencode.RawMessage `bencode:"4"`
}
@ -94,13 +126,15 @@ func (r *Request) UnmarshalBencode(b []byte) error {
return nil
}
type storeArgsValue struct {
Token string `bencode:"token"`
LbryID bitmap `bencode:"lbryid"`
Port int `bencode:"port"`
}
type storeArgs struct {
BlobHash string
Value struct {
Token string `bencode:"token"`
LbryID string `bencode:"lbryid"`
Port int `bencode:"port"`
}
BlobHash string
Value storeArgsValue
NodeID bitmap
SelfStore bool // this is an int on the wire
}
@ -167,8 +201,8 @@ func (s *storeArgs) UnmarshalBencode(b []byte) error {
}
type Response struct {
ID string
NodeID string
ID messageID
NodeID bitmap
Data string
FindNodeData []Node
FindValueKey string
@ -219,8 +253,8 @@ func (r Response) MarshalBencode() ([]byte, error) {
func (r *Response) UnmarshalBencode(b []byte) error {
var raw struct {
ID string `bencode:"1"`
NodeID string `bencode:"2"`
ID messageID `bencode:"1"`
NodeID bitmap `bencode:"2"`
Data bencode.RawMessage `bencode:"3"`
}
err := bencode.DecodeBytes(b, &raw)
@ -269,10 +303,10 @@ func (r *Response) UnmarshalBencode(b []byte) error {
}
type Error struct {
ID string
NodeID string
Response []string
ID messageID
NodeID bitmap
ExceptionType string
Response []string
}
func (e Error) MarshalBencode() ([]byte, error) {
@ -284,3 +318,29 @@ func (e Error) MarshalBencode() ([]byte, error) {
headerArgsField: e.Response,
})
}
func (e *Error) UnmarshalBencode(b []byte) error {
var raw struct {
ID messageID `bencode:"1"`
NodeID bitmap `bencode:"2"`
ExceptionType string `bencode:"3"`
Args interface{} `bencode:"4"`
}
err := bencode.DecodeBytes(b, &raw)
if err != nil {
return err
}
e.ID = raw.ID
e.NodeID = raw.NodeID
e.ExceptionType = raw.ExceptionType
if reflect.TypeOf(raw.Args).Kind() == reflect.Slice {
v := reflect.ValueOf(raw.Args)
for i := 0; i < v.Len(); i++ {
e.Response = append(e.Response, cast.ToString(v.Index(i).Interface()))
}
}
return nil
}

View file

@ -48,7 +48,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
t.Error("blob hash mismatch")
}
if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) {
if storeArgs.Value.LbryID.Hex() != strings.ToLower(lbryID) {
t.Error("lbryid mismatch")
}
if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port {
@ -76,7 +76,7 @@ func TestBencodeDecodeStoreArgs(t *testing.T) {
func TestBencodeFindNodesResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: newRandomBitmap().RawString(),
NodeID: newRandomBitmap(),
FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
{id: newRandomBitmap(), ip: net.IPv4(4, 3, 2, 1).To4(), port: 8765},
@ -100,7 +100,7 @@ func TestBencodeFindNodesResponse(t *testing.T) {
func TestBencodeFindValueResponse(t *testing.T) {
res := Response{
ID: newMessageID(),
NodeID: newRandomBitmap().RawString(),
NodeID: newRandomBitmap(),
FindValueKey: newRandomBitmap().RawString(),
FindNodeData: []Node{
{id: newRandomBitmap(), ip: net.IPv4(1, 2, 3, 4).To4(), port: 5678},
@ -125,8 +125,8 @@ func compareResponses(t *testing.T, res, res2 Response) {
if res.ID != res2.ID {
t.Errorf("expected ID %s, got %s", res.ID, res2.ID)
}
if res.NodeID != res2.NodeID {
t.Errorf("expected NodeID %s, got %s", res.NodeID, res2.NodeID)
if !res.NodeID.Equals(res2.NodeID) {
t.Errorf("expected NodeID %s, got %s", res.NodeID.Hex(), res2.NodeID.Hex())
}
if res.Data != res2.Data {
t.Errorf("expected Data %s, got %s", res.Data, res2.Data)

View file

@ -1,28 +1,16 @@
package dht
import (
"crypto/rand"
"encoding/hex"
"net"
"reflect"
"strings"
"time"
"github.com/davecgh/go-spew/spew"
"github.com/lyoshenka/bencode"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
)
func newMessageID() string {
buf := make([]byte, messageIDLength)
_, err := rand.Read(buf)
if err != nil {
panic(err)
}
return string(buf)
}
// handlePacket handles packets received from udp.
func handlePacket(dht *DHT, pkt packet) {
//log.Debugf("[%s] Received message from %s:%s (%d bytes) %s", dht.node.id.HexShort(), pkt.raddr.IP.String(), strconv.Itoa(pkt.raddr.Port), len(pkt.data), hex.EncodeToString(pkt.data))
@ -48,7 +36,7 @@ func handlePacket(dht *DHT, pkt packet) {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.HexShort(), request.ID.HexShort(), request.NodeID.HexShort(), request.Method, argsToString(request.Args))
handleRequest(dht, pkt.raddr, request)
case responseType:
@ -58,17 +46,17 @@ func handlePacket(dht *DHT, pkt packet) {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.ArgsDebug())
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.HexShort(), response.ID.HexShort(), response.NodeID.HexShort(), response.ArgsDebug())
handleResponse(dht, pkt.raddr, response)
case errorType:
e := Error{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
ExceptionType: data[headerPayloadField].(string),
Response: getArgs(data[headerArgsField]),
e := Error{}
err = bencode.DecodeBytes(pkt.data, &e)
if err != nil {
log.Errorln(err)
return
}
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.HexShort(), e.ID.HexShort(), e.NodeID.HexShort(), e.ExceptionType)
handleError(dht, pkt.raddr, e)
default:
@ -79,14 +67,14 @@ func handlePacket(dht *DHT, pkt packet) {
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
if request.NodeID == dht.node.id.RawString() {
if request.NodeID.Equals(dht.node.id) {
log.Warn("ignoring self-request")
return
}
switch request.Method {
case pingMethod:
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: pingSuccessResponse})
case storeMethod:
if request.StoreArgs.BlobHash == "" {
log.Errorln("blobhash is empty")
@ -95,7 +83,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
// TODO: we should be sending the IP in the request, not just using the sender's IP
// TODO: should we be using StoreArgs.NodeID or StoreArgs.Value.LbryID ???
dht.store.Upsert(request.StoreArgs.BlobHash, Node{id: request.StoreArgs.NodeID, ip: addr.IP, port: request.StoreArgs.Value.Port})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id, Data: storeSuccessResponse})
case findNodeMethod:
if len(request.Args) < 1 {
log.Errorln("nothing to find")
@ -117,7 +105,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
}
if nodes := dht.store.Get(request.Args[0]); len(nodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString()}
response := Response{ID: request.ID, NodeID: dht.node.id}
response.FindValueKey = request.Args[0]
response.FindNodeData = nodes
send(dht, addr, response)
@ -131,7 +119,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) {
return
}
node := Node{id: newBitmapFromString(request.NodeID), ip: addr.IP, port: addr.Port}
node := Node{id: request.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
@ -139,7 +127,7 @@ func doFindNodes(dht *DHT, addr *net.UDPAddr, request Request) {
nodeID := newBitmapFromString(request.Args[0])
closestNodes := dht.rt.GetClosest(nodeID, bucketSize)
if len(closestNodes) > 0 {
response := Response{ID: request.ID, NodeID: dht.node.id.RawString(), FindNodeData: make([]Node, len(closestNodes))}
response := Response{ID: request.ID, NodeID: dht.node.id, FindNodeData: make([]Node, len(closestNodes))}
for i, n := range closestNodes {
response.FindNodeData[i] = n
}
@ -156,14 +144,14 @@ func handleResponse(dht *DHT, addr *net.UDPAddr, response Response) {
tx.res <- &response
}
node := Node{id: newBitmapFromString(response.NodeID), ip: addr.IP, port: addr.Port}
node := Node{id: response.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
// handleError handles errors received from udp.
func handleError(dht *DHT, addr *net.UDPAddr, e Error) {
spew.Dump(e)
node := Node{id: newBitmapFromString(e.NodeID), ip: addr.IP, port: addr.Port}
node := Node{id: e.NodeID, ip: addr.IP, port: addr.Port}
dht.rt.Update(node)
}
@ -176,10 +164,10 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Debugf("[%s] query %s: sending request to %s (%d bytes) %s(%s)",
dht.node.id.HexShort(), hex.EncodeToString([]byte(req.ID))[:8], addr.String(), len(encoded), req.Method, argsToString(req.Args))
dht.node.id.HexShort(), req.ID.HexShort(), addr.String(), len(encoded), req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response to %s (%d bytes) %s",
dht.node.id.HexShort(), hex.EncodeToString([]byte(res.ID))[:8], addr.String(), len(encoded), res.ArgsDebug())
dht.node.id.HexShort(), res.ID.HexShort(), addr.String(), len(encoded), res.ArgsDebug())
} else {
log.Debugf("[%s] (%d bytes) %s", dht.node.id.HexShort(), len(encoded), spew.Sdump(data))
}
@ -190,17 +178,6 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
return err
}
func getArgs(argsInt interface{}) []string {
var args []string
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
args = append(args, cast.ToString(v.Index(i).Interface()))
}
}
return args
}
func argsToString(args []string) string {
argsCopy := make([]string, len(args))
copy(argsCopy, args)

View file

@ -151,7 +151,7 @@ func TestPing(t *testing.T) {
rMessageID, ok := response[headerMessageIDField].(string)
if !ok {
t.Error("message ID is not a string")
} else if rMessageID != messageID {
} else if rMessageID != string(messageID[:]) {
t.Error("unexpected message ID")
}
}
@ -203,16 +203,18 @@ func TestStore(t *testing.T) {
storeRequest := Request{
ID: messageID,
NodeID: testNodeID.RawString(),
NodeID: testNodeID,
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: blobHashToStore,
NodeID: testNodeID,
Value: storeArgsValue{
Token: "arst",
LbryID: testNodeID,
Port: 9999,
},
NodeID: testNodeID,
},
}
storeRequest.StoreArgs.Value.Token = "arst"
storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString()
storeRequest.StoreArgs.Value.Port = 9999
_ = "64 " + // start message
"313A30 693065" + // type: 0
@ -305,7 +307,7 @@ func TestFindNode(t *testing.T) {
request := Request{
ID: messageID,
NodeID: testNodeID.RawString(),
NodeID: testNodeID,
Method: findNodeMethod,
Args: []string{blobHashToFind},
}
@ -390,7 +392,7 @@ func TestFindValueExisting(t *testing.T) {
request := Request{
ID: messageID,
NodeID: testNodeID.RawString(),
NodeID: testNodeID,
Method: findValueMethod,
Args: []string{valueToFind},
}
@ -468,7 +470,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
request := Request{
ID: messageID,
NodeID: testNodeID.RawString(),
NodeID: testNodeID,
Method: findValueMethod,
Args: []string{valueToFind},
}
@ -517,7 +519,7 @@ func TestFindValueFallbackToFindNode(t *testing.T) {
verifyContacts(t, contacts, nodes)
}
func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNodeID string) {
func verifyResponse(t *testing.T, resp map[string]interface{}, id messageID, dhtNodeID string) {
if len(resp) != 4 {
t.Errorf("expected 4 response fields, got %d", len(resp))
}
@ -541,7 +543,7 @@ func verifyResponse(t *testing.T, resp map[string]interface{}, messageID, dhtNod
rMessageID, ok := resp[headerMessageIDField].(string)
if !ok {
t.Error("message ID is not a string")
} else if rMessageID != messageID {
} else if rMessageID != string(id[:]) {
t.Error("unexpected message ID")
}
if len(rMessageID) != messageIDLength {

View file

@ -19,7 +19,7 @@ type transaction struct {
// transactionManager represents the manager of transactions.
type transactionManager struct {
lock *sync.RWMutex
transactions map[string]*transaction
transactions map[messageID]*transaction
dht *DHT
}
@ -27,7 +27,7 @@ type transactionManager struct {
func newTransactionManager(dht *DHT) *transactionManager {
return &transactionManager{
lock: &sync.RWMutex{},
transactions: make(map[string]*transaction),
transactions: make(map[messageID]*transaction),
dht: dht,
}
}
@ -40,14 +40,14 @@ func (tm *transactionManager) insert(trans *transaction) {
}
// delete removes a transaction from transactionManager.
func (tm *transactionManager) delete(transID string) {
func (tm *transactionManager) delete(id messageID) {
tm.lock.Lock()
defer tm.lock.Unlock()
delete(tm.transactions, transID)
delete(tm.transactions, id)
}
// find transaction for id. optionally ensure that addr matches node from transaction
func (tm *transactionManager) Find(id string, addr *net.UDPAddr) *transaction {
func (tm *transactionManager) Find(id messageID, addr *net.UDPAddr) *transaction {
tm.lock.RLock()
defer tm.lock.RUnlock()
@ -73,7 +73,7 @@ func (tm *transactionManager) SendAsync(ctx context.Context, node Node, req *Req
defer close(ch)
req.ID = newMessageID()
req.NodeID = tm.dht.node.id.RawString()
req.NodeID = tm.dht.node.id
trans := &transaction{
node: node,
req: req,