better tests, better bencoding

This commit is contained in:
Alex Grintsvayg 2018-03-07 16:15:58 -05:00
parent f565d0b78f
commit 006a49bd67
6 changed files with 517 additions and 169 deletions

View file

@ -52,6 +52,7 @@ type DHT struct {
node *Node
routingTable *RoutingTable
packets chan packet
store *peerStore
}
// New returns a DHT pointer. If config is nil, then config will be set to the default config.
@ -72,6 +73,7 @@ func New(config *Config) *DHT {
node: node,
routingTable: NewRoutingTable(node),
packets: make(chan packet),
store: newPeerStore(),
}
}
@ -150,39 +152,33 @@ func handle(dht *DHT, pkt packet) {
var data map[string]interface{}
err := bencode.DecodeBytes(pkt.data, &data)
if err != nil {
log.Errorf("Error decoding data: %s\n%s", err, pkt.data)
log.Errorf("error decoding data: %s\n%s", err, pkt.data)
return
}
msgType, ok := data[headerTypeField]
if !ok {
log.Errorf("Decoded data has no message type: %s", data)
log.Errorf("decoded data has no message type: %s", data)
return
}
switch msgType.(int64) {
case requestType:
request := Request{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
Method: data[headerPayloadField].(string),
Args: getArgs(data[headerArgsField]),
request := Request{}
err = bencode.DecodeBytes(pkt.data, &request)
if err != nil {
return
}
log.Infof("%s: Received from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
log.Debugf("[%s] query %s: received request from %s: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(request.ID))[:8], hex.EncodeToString([]byte(request.NodeID))[:8], request.Method, argsToString(request.Args))
handleRequest(dht, pkt.raddr, request)
case responseType:
response := Response{
ID: data[headerMessageIDField].(string),
NodeID: data[headerNodeIDField].(string),
response := Response{}
err = bencode.DecodeBytes(pkt.data, &response)
if err != nil {
return
}
if reflect.TypeOf(data[headerPayloadField]).Kind() == reflect.String {
response.Data = data[headerPayloadField].(string)
} else {
response.FindNodeData = getFindNodeResponse(data[headerPayloadField])
}
log.Debugf("[%s] query %s: received response from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(response.ID))[:8], hex.EncodeToString([]byte(response.NodeID))[:8], response.Data)
handleResponse(dht, pkt.raddr, response)
case errorType:
@ -192,6 +188,7 @@ func handle(dht *DHT, pkt packet) {
ExceptionType: data[headerPayloadField].(string),
Response: getArgs(data[headerArgsField]),
}
log.Debugf("[%s] query %s: received error from %s: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(e.ID))[:8], hex.EncodeToString([]byte(e.NodeID))[:8], e.ExceptionType)
handleError(dht, pkt.raddr, e)
default:
@ -202,7 +199,6 @@ func handle(dht *DHT, pkt packet) {
// handleRequest handles the requests received from udp.
func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool) {
log.Infoln("handling request")
if request.NodeID == dht.node.id.RawString() {
log.Warn("ignoring self-request")
return
@ -211,11 +207,16 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool)
switch request.Method {
case pingMethod:
log.Println("ping")
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: "pong"})
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: pingSuccessResponse})
case storeMethod:
log.Println("store")
node := &Node{id: newBitmapFromHex(request.StoreArgs.Value.LbryID), addr: request.StoreArgs.Value.Port}
dht.store.Insert(newBitmapFromHex(request.StoreArgs.BlobHash))
send(dht, addr, Response{ID: request.ID, NodeID: dht.node.id.RawString(), Data: storeSuccessResponse})
case findNodeMethod:
log.Println("findnode")
case findValueMethod:
log.Println("findvalue")
//if len(request.Args) < 1 {
// send(dht, addr, Error{ID: request.ID, NodeID: dht.node.id.RawString(), Response: []string{"No target"}})
// return
@ -244,6 +245,7 @@ func handleRequest(dht *DHT, addr *net.UDPAddr, request Request) (success bool)
default:
// send(dht, addr, makeError(t, protocolError, "invalid q"))
log.Errorln("invalid request method")
return
}
@ -282,11 +284,13 @@ func handleError(dht *DHT, addr *net.UDPAddr, e Error) (success bool) {
// send sends data to the udp.
func send(dht *DHT, addr *net.UDPAddr, data Message) error {
if req, ok := data.(Request); ok {
log.Infof("%s: Sending %s(%s)", hex.EncodeToString([]byte(req.NodeID))[:8], req.Method, argsToString(req.Args))
log.Debugf("[%s] query %s: sending request: %s(%s)", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(req.ID))[:8], req.Method, argsToString(req.Args))
} else if res, ok := data.(Response); ok {
log.Debugf("[%s] query %s: sending response: %s", dht.node.id.Hex()[:8], hex.EncodeToString([]byte(res.ID))[:8], res.Data)
} else {
log.Infof("%s: Sending %s", data.GetID(), spew.Sdump(data))
log.Debugf("[%s] %s", spew.Sdump(data))
}
encoded, err := data.Encode()
encoded, err := bencode.EncodeBytes(data)
if err != nil {
return err
}
@ -298,46 +302,15 @@ func send(dht *DHT, addr *net.UDPAddr, data Message) error {
return err
}
func getFindNodeResponse(i interface{}) (data []findNodeDatum) {
if reflect.TypeOf(i).Kind() != reflect.Slice {
return
}
v := reflect.ValueOf(i)
for i := 0; i < v.Len(); i++ {
if v.Index(i).Kind() != reflect.Interface {
continue
}
contact := v.Index(i).Elem()
if contact.Type().Kind() != reflect.Slice || contact.Len() != 3 {
continue
}
if contact.Index(0).Elem().Kind() != reflect.String ||
contact.Index(1).Elem().Kind() != reflect.String ||
!(contact.Index(2).Elem().Kind() == reflect.Int64 ||
contact.Index(2).Elem().Kind() == reflect.Int) {
continue
}
data = append(data, findNodeDatum{
ID: contact.Index(0).Elem().String(),
IP: contact.Index(1).Elem().String(),
Port: int(contact.Index(2).Elem().Int()),
})
}
return
}
func getArgs(argsInt interface{}) (args []string) {
func getArgs(argsInt interface{}) []string {
args := []string{}
if reflect.TypeOf(argsInt).Kind() == reflect.Slice {
v := reflect.ValueOf(argsInt)
for i := 0; i < v.Len(); i++ {
args = append(args, cast.ToString(v.Index(i).Interface()))
}
}
return
return args
}
func argsToString(args []string) string {

View file

@ -1,13 +1,17 @@
package dht
import (
"encoding/hex"
"testing"
"time"
"github.com/davecgh/go-spew/spew"
log "github.com/sirupsen/logrus"
"github.com/zeebo/bencode"
)
func TestPing(t *testing.T) {
log.SetLevel(log.DebugLevel)
dhtNodeID := newRandomBitmap()
testNodeID := newRandomBitmap()
@ -111,17 +115,41 @@ func TestStore(t *testing.T) {
go dht.runHandler()
messageID := newRandomBitmap().RawString()
idToStore := newRandomBitmap().RawString()
blobHashToStore := newRandomBitmap().RawString()
data, err := bencode.EncodeBytes(map[string]interface{}{
headerTypeField: requestType,
headerMessageIDField: messageID,
headerNodeIDField: testNodeID.RawString(),
headerPayloadField: "store",
headerArgsField: []string{idToStore},
})
storeRequest := Request{
ID: messageID,
NodeID: testNodeID.RawString(),
Method: storeMethod,
StoreArgs: &storeArgs{
BlobHash: blobHashToStore,
},
}
storeRequest.StoreArgs.Value.Token = "arst"
storeRequest.StoreArgs.Value.LbryID = testNodeID.RawString()
storeRequest.StoreArgs.Value.Port = 9999
_ = "64 " + // start message
"313A30 693065" + // type: 0
"313A31 3230 3A 6EB490B5788B63F0F7E6D92352024D0CBDEC2D3A" + // message id
"313A32 3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // node id
"313A33 35 3A 73746F7265" + // method
"313A34 6C" + // start args list
"3438 3A 3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468" + // block hash
"64" + // start value dict
"363A6C6272796964 3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // lbry id
"343A706F7274 69 33333333 65" + // port
"353A746F6B656E 3438 3A 17C2D8E1E48EF21567FE4AD5C8ED944B798D3B65AB58D0C9122AD6587D1B5FED472EA2CB12284CEFA1C21EFF302322BD" + // token
"65" + // end value dict
"3438 3A 7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B" + // node id
"693065" + // self store (integer)
"65" + // end args list
"65" // end message
data, err := bencode.EncodeBytes(storeRequest)
if err != nil {
panic(err)
t.Error(err)
return
}
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
@ -191,3 +219,63 @@ func TestStore(t *testing.T) {
}
}
}
func TestFindNode(t *testing.T) {
dhtNodeID := newRandomBitmap()
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()})
dht.conn = conn
dht.listen()
go dht.runHandler()
data, _ := hex.DecodeString("64313a30693065313a3132303a2afdf2272981651a2c64e39ab7f04ec2d3b5d5d2313a3234383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b313a33383a66696e644e6f6465313a346c34383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b6565")
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
timer := time.NewTimer(3 * time.Second)
select {
case <-timer.C:
t.Error("timeout")
case resp := <-conn.writes:
var response map[string]interface{}
err := bencode.DecodeBytes(resp.data, &response)
if err != nil {
t.Error(err)
return
}
spew.Dump(response)
}
}
func TestFindValue(t *testing.T) {
dhtNodeID := newRandomBitmap()
conn := newTestUDPConn("127.0.0.1:21217")
dht := New(&Config{Address: ":21216", NodeID: dhtNodeID.Hex()})
dht.conn = conn
dht.listen()
go dht.runHandler()
data, _ := hex.DecodeString("6469306569306569316532303a7de8e57d34e316abbb5a8a8da50dcd1ad4c80e0f69326534383a7ce1b831dec8689e44f80f547d2dea171f6a625e1a4ff6c6165e645f953103dabeb068a622203f859c6c64658fd3aa3b693365393a66696e6456616c75656934656c34383aa47624b8e7ee1e54df0c45e2eb858feb0b705bd2a78d8b739be31ba188f4bd6f56b371c51fecc5280d5fd26ba4168e966565")
conn.toRead <- testUDPPacket{addr: conn.addr, data: data}
timer := time.NewTimer(3 * time.Second)
select {
case <-timer.C:
t.Error("timeout")
case resp := <-conn.writes:
var response map[string]interface{}
err := bencode.DecodeBytes(resp.data, &response)
if err != nil {
t.Error(err)
return
}
spew.Dump(response)
}
}

267
dht/message.go Normal file
View file

@ -0,0 +1,267 @@
package dht
import (
"github.com/lbryio/errors.go"
"github.com/spf13/cast"
"github.com/zeebo/bencode"
)
const (
pingMethod = "ping"
storeMethod = "store"
findNodeMethod = "findNode"
findValueMethod = "findValue"
)
const (
pingSuccessResponse = "pong"
storeSuccessResponse = "OK"
)
const (
requestType = 0
responseType = 1
errorType = 2
)
const (
// these are strings because bencode requires bytestring keys
headerTypeField = "0"
headerMessageIDField = "1" // message id is 20 bytes long
headerNodeIDField = "2" // node id is 48 bytes long
headerPayloadField = "3"
headerArgsField = "4"
)
type Message interface {
bencode.Marshaler
GetID() string
}
type Request struct {
ID string
NodeID string
Method string
Args []string
StoreArgs *storeArgs
}
func (r Request) GetID() string { return r.ID }
func (r Request) MarshalBencode() ([]byte, error) {
var args interface{}
if r.StoreArgs != nil {
args = r.StoreArgs
} else {
args = r.Args
}
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: requestType,
headerMessageIDField: r.ID,
headerNodeIDField: r.NodeID,
headerPayloadField: r.Method,
headerArgsField: args,
})
}
func (r *Request) UnmarshalBencode(b []byte) error {
var raw struct {
ID string `bencode:"1"`
NodeID string `bencode:"2"`
Method string `bencode:"3"`
Args bencode.RawMessage `bencode:"4"`
}
err := bencode.DecodeBytes(b, &raw)
if err != nil {
return err
}
r.ID = raw.ID
r.NodeID = raw.NodeID
r.Method = raw.Method
if r.Method == storeMethod {
err = bencode.DecodeBytes(raw.Args, &r.StoreArgs)
} else {
err = bencode.DecodeBytes(raw.Args, &r.Args)
}
if err != nil {
return err
}
return nil
}
type storeArgs struct {
BlobHash string // 48 bytes
Value struct {
Token string `bencode:"token"`
LbryID string `bencode:"lbryid"`
Port int `bencode:"port"`
}
NodeID string // 48 bytes
SelfStore bool // this is an int on the wire
}
func (s *storeArgs) MarshalBencode() ([]byte, error) {
encodedValue, err := bencode.EncodeString(s.Value)
if err != nil {
return nil, err
}
selfStoreStr := 0
if s.SelfStore {
selfStoreStr = 1
}
return bencode.EncodeBytes([]interface{}{
s.BlobHash,
bencode.RawMessage(encodedValue),
s.NodeID,
selfStoreStr,
})
}
func (s *storeArgs) UnmarshalBencode(b []byte) error {
var argsInt []bencode.RawMessage
err := bencode.DecodeBytes(b, &argsInt)
if err != nil {
return err
}
if len(argsInt) != 4 {
return errors.Err("unexpected number of fields for store args. got " + cast.ToString(len(argsInt)))
}
err = bencode.DecodeBytes(argsInt[0], &s.BlobHash)
if err != nil {
return errors.Err(err)
}
err = bencode.DecodeBytes(argsInt[1], &s.Value)
if err != nil {
return errors.Err(err)
}
err = bencode.DecodeBytes(argsInt[2], &s.NodeID)
if err != nil {
return errors.Err(err)
}
var selfStore int
err = bencode.DecodeBytes(argsInt[3], &selfStore)
if err != nil {
return errors.Err(err)
}
if selfStore == 0 {
s.SelfStore = false
} else if selfStore == 1 {
s.SelfStore = true
} else {
return errors.Err("selfstore must be 1 or 0")
}
return nil
}
type findNodeDatum struct {
ID bitmap
IP string
Port int
}
func (f *findNodeDatum) UnmarshalBencode(b []byte) error {
var contact []bencode.RawMessage
err := bencode.DecodeBytes(b, &contact)
if err != nil {
return err
}
if len(contact) != 3 {
return errors.Err("invalid-sized contact")
}
err = bencode.DecodeBytes(contact[0], &f.ID)
if err != nil {
return err
}
err = bencode.DecodeBytes(contact[1], &f.IP)
if err != nil {
return err
}
err = bencode.DecodeBytes(contact[2], &f.Port)
if err != nil {
return err
}
return nil
}
type Response struct {
ID string
NodeID string
Data string
FindNodeData []findNodeDatum
}
func (r Response) GetID() string { return r.ID }
func (r Response) MarshalBencode() ([]byte, error) {
data := map[string]interface{}{
headerTypeField: responseType,
headerMessageIDField: r.ID,
headerNodeIDField: r.NodeID,
}
if r.Data != "" {
data[headerPayloadField] = r.Data
} else {
var nodes []interface{}
for _, n := range r.FindNodeData {
nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port})
}
data[headerPayloadField] = nodes
}
return bencode.EncodeBytes(data)
}
func (r *Response) UnmarshalBencode(b []byte) error {
var raw struct {
ID string `bencode:"1"`
NodeID string `bencode:"2"`
Data bencode.RawMessage `bencode:"2"`
}
err := bencode.DecodeBytes(b, &raw)
if err != nil {
return err
}
r.ID = raw.ID
r.NodeID = raw.NodeID
err = bencode.DecodeBytes(raw.Data, &r.Data)
if err != nil {
err = bencode.DecodeBytes(raw.Data, r.FindNodeData)
if err != nil {
return err
}
}
}
type Error struct {
ID string
NodeID string
Response []string
ExceptionType string
}
func (e Error) GetID() string { return e.ID }
func (e Error) MarshalBencode() ([]byte, error) {
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: errorType,
headerMessageIDField: e.ID,
headerNodeIDField: e.NodeID,
headerPayloadField: e.ExceptionType,
headerArgsField: e.Response,
})
}

75
dht/message_test.go Normal file
View file

@ -0,0 +1,75 @@
package dht
import (
"encoding/hex"
"reflect"
"strconv"
"strings"
"testing"
log "github.com/sirupsen/logrus"
"github.com/zeebo/bencode"
)
func TestBencodeDecodeStoreArgs(t *testing.T) {
log.SetLevel(log.DebugLevel)
blobHash := "3214D6C2F77FCB5E8D5FC07EDAFBA614F031CE8B2EAB49F924F8143F6DFBADE048D918710072FB98AB1B52B58F4E1468"
lbryID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B"
port := hex.EncodeToString([]byte("3333"))
token := "17C2D8E1E48EF21567FE4AD5C8ED944B798D3B65AB58D0C9122AD6587D1B5FED472EA2CB12284CEFA1C21EFF302322BD"
nodeID := "7CE1B831DEC8689E44F80F547D2DEA171F6A625E1A4FF6C6165E645F953103DABEB068A622203F859C6C64658FD3AA3B"
selfStore := hex.EncodeToString([]byte("1"))
raw := "6C" + // start args list
"3438 3A " + blobHash + // blob hash
"64" + // start value dict
"363A6C6272796964 3438 3A " + lbryID + // lbry id
"343A706F7274 69 " + port + " 65" + // port
"353A746F6B656E 3438 3A " + token + // token
"65" + // end value dict
"3438 3A " + nodeID + // node id
"69 " + selfStore + " 65" + // self store (integer)
"65" // end args list
raw = strings.ToLower(strings.Replace(raw, " ", "", -1))
data, err := hex.DecodeString(raw)
if err != nil {
t.Error(err)
return
}
storeArgs := &storeArgs{}
err = bencode.DecodeBytes(data, storeArgs)
if err != nil {
t.Error(err)
}
if hex.EncodeToString([]byte(storeArgs.BlobHash)) != strings.ToLower(blobHash) {
t.Error("blob hash mismatch")
}
if hex.EncodeToString([]byte(storeArgs.Value.LbryID)) != strings.ToLower(lbryID) {
t.Error("lbryid mismatch")
}
if hex.EncodeToString([]byte(strconv.Itoa(storeArgs.Value.Port))) != port {
t.Error("port mismatch")
}
if hex.EncodeToString([]byte(storeArgs.Value.Token)) != strings.ToLower(token) {
t.Error("token mismatch")
}
if hex.EncodeToString([]byte(storeArgs.NodeID)) != strings.ToLower(nodeID) {
t.Error("node id mismatch")
}
if !storeArgs.SelfStore {
t.Error("selfStore mismatch")
}
reencoded, err := bencode.EncodeBytes(storeArgs)
if err != nil {
t.Error(err)
} else if !reflect.DeepEqual(reencoded, data) {
t.Error("reencoded data does not match original")
//spew.Dump(reencoded, data)
}
}

View file

@ -1,103 +0,0 @@
package dht
import "github.com/zeebo/bencode"
const (
pingMethod = "ping"
storeMethod = "store"
findNodeMethod = "findNode"
findValueMethod = "findValue"
)
const (
pingSuccessResponse = "pong"
storeSuccessResponse = "OK"
)
const (
requestType = 0
responseType = 1
errorType = 2
)
const (
// these are strings because bencode requires bytestring keys
headerTypeField = "0"
headerMessageIDField = "1"
headerNodeIDField = "2"
headerPayloadField = "3"
headerArgsField = "4"
)
type Message interface {
GetID() string
Encode() ([]byte, error)
}
type Request struct {
ID string
NodeID string
Method string
Args []string
}
func (r Request) GetID() string { return r.ID }
func (r Request) Encode() ([]byte, error) {
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: requestType,
headerMessageIDField: r.ID,
headerNodeIDField: r.NodeID,
headerPayloadField: r.Method,
headerArgsField: r.Args,
})
}
type findNodeDatum struct {
ID string
IP string
Port int
}
type Response struct {
ID string
NodeID string
Data string
FindNodeData []findNodeDatum
}
func (r Response) GetID() string { return r.ID }
func (r Response) Encode() ([]byte, error) {
data := map[string]interface{}{
headerTypeField: responseType,
headerMessageIDField: r.ID,
headerNodeIDField: r.NodeID,
}
if r.Data != "" {
data[headerPayloadField] = r.Data
} else {
var nodes []interface{}
for _, n := range r.FindNodeData {
nodes = append(nodes, []interface{}{n.ID, n.IP, n.Port})
}
data[headerPayloadField] = nodes
}
return bencode.EncodeBytes(data)
}
type Error struct {
ID string
NodeID string
Response []string
ExceptionType string
}
func (e Error) GetID() string { return e.ID }
func (e Error) Encode() ([]byte, error) {
return bencode.EncodeBytes(map[string]interface{}{
headerTypeField: errorType,
headerMessageIDField: e.ID,
headerNodeIDField: e.NodeID,
headerPayloadField: e.ExceptionType,
headerArgsField: e.Response,
})
}

48
dht/store.go Normal file
View file

@ -0,0 +1,48 @@
package dht
import (
"sync"
"time"
)
type peer struct {
node *Node
lastPublished time.Time
originallyPublished time.Time
originalPublisherID bitmap
}
type peerStore struct {
data map[bitmap][]peer
lock sync.RWMutex
}
func newPeerStore() *peerStore {
return &peerStore{
data: map[bitmap][]peer{},
}
}
func (s *peerStore) Insert(key bitmap, node *Node, lastPublished, originallyPublished time.Time, originaPublisherID bitmap) {
s.lock.Lock()
defer s.lock.Unlock()
newPeer := peer{node: node, lastPublished: lastPublished, originallyPublished: originallyPublished, originalPublisherID: originaPublisherID}
_, ok := s.data[key]
if !ok {
s.data[key] = []peer{newPeer}
} else {
s.data[key] = append(s.data[key], newPeer)
}
}
func (s *peerStore) GetNodes(key bitmap) []*Node {
s.lock.RLock()
defer s.lock.RUnlock()
nodes := []*Node{}
if peers, ok := s.data[key]; ok {
for _, p := range peers {
nodes = append(nodes, p.node)
}
}
return nodes
}