some work on v2 of proto

This commit is contained in:
Alex Grintsvayg 2017-08-15 16:02:18 -04:00
parent 782090188e
commit 8ee6b26feb
5 changed files with 148 additions and 36 deletions

View file

@ -9,7 +9,8 @@ import (
) )
type Client struct { type Client struct {
conn net.Conn conn net.Conn
connected bool
} }
func (c *Client) Connect(address string) error { func (c *Client) Connect(address string) error {
@ -18,13 +19,19 @@ func (c *Client) Connect(address string) error {
if err != nil { if err != nil {
return err return err
} }
c.connected = true
return c.doHandshake(protocolVersion1) return c.doHandshake(protocolVersion1)
} }
func (c *Client) Close() error { func (c *Client) Close() error {
c.connected = false
return c.conn.Close() return c.conn.Close()
} }
func (c *Client) SendBlob(blob []byte) error { func (c *Client) SendBlob(blob []byte) error {
if !c.connected {
return fmt.Errorf("Not connected")
}
if len(blob) != BlobSize { if len(blob) != BlobSize {
return fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes") return fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes")
} }
@ -37,7 +44,6 @@ func (c *Client) SendBlob(blob []byte) error {
if err != nil { if err != nil {
return err return err
} }
_, err = c.conn.Write(sendRequest) _, err = c.conn.Write(sendRequest)
if err != nil { if err != nil {
return err return err
@ -75,6 +81,10 @@ func (c *Client) SendBlob(blob []byte) error {
} }
func (c *Client) doHandshake(version int) error { func (c *Client) doHandshake(version int) error {
if !c.connected {
return fmt.Errorf("Not connected")
}
handshake, err := json.Marshal(handshakeRequestResponse{Version: version}) handshake, err := json.Marshal(handshakeRequestResponse{Version: version})
if err != nil { if err != nil {
return err return err

60
client_test.go Normal file
View file

@ -0,0 +1,60 @@
package main
import (
"io/ioutil"
"math/rand"
"os"
"strconv"
"testing"
"time"
)
var address = "localhost:" + strconv.Itoa(DefaultPort)
var s Server
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
dir, err := ioutil.TempDir("", "reflector_client_test")
if err != nil {
panic(err)
}
defer os.RemoveAll(dir)
s := NewServer(dir)
go s.ListenAndServe(address)
os.Exit(m.Run())
}
func TestNotConnected(t *testing.T) {
c := Client{}
err := c.SendBlob([]byte{})
if err == nil {
t.Error("client should error if it is not connected")
}
}
func TestSmallBlob(t *testing.T) {
c := Client{}
err := c.Connect(address)
if err != nil {
t.Error(err)
}
err = c.SendBlob([]byte{})
if err == nil {
t.Error("client should error if blob is empty")
}
blob := make([]byte, 1000)
_, err = rand.Read(blob)
if err != nil {
t.Error("failed to make random blob")
}
err = c.SendBlob([]byte{})
if err == nil {
t.Error("client should error if blob is the wrong size")
}
}

16
main.go
View file

@ -4,6 +4,7 @@ import (
"flag" "flag"
"log" "log"
"math/rand" "math/rand"
"strconv"
"time" "time"
) )
@ -17,7 +18,8 @@ func main() {
var err error var err error
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
address := "localhost:5566" port := DefaultPort
address := "52.14.109.125:" + strconv.Itoa(port)
serve := flag.Bool("server", false, "Run server") serve := flag.Bool("server", false, "Run server")
blobDir := flag.String("blobdir", "", "Where blobs will be saved to") blobDir := flag.String("blobdir", "", "Where blobs will be saved to")
@ -49,16 +51,4 @@ func main() {
checkErr(err) checkErr(err)
err = client.SendBlob(blob) err = client.SendBlob(blob)
checkErr(err) checkErr(err)
blob = make([]byte, 2*1024*1024)
_, err = rand.Read(blob)
checkErr(err)
err = client.SendBlob(blob)
checkErr(err)
blob = make([]byte, 2*1024*1024)
_, err = rand.Read(blob)
checkErr(err)
err = client.SendBlob(blob)
checkErr(err)
} }

View file

@ -87,18 +87,20 @@ func (s *Server) doError(conn net.Conn, e error) error {
} }
func (s *Server) receiveBlob(conn net.Conn) error { func (s *Server) receiveBlob(conn net.Conn) error {
blobSize, blobHash, err := s.readBlobRequest(conn) blobSize, blobHash, isSdBlob, err := s.readBlobRequest(conn)
if err != nil { if err != nil {
return err return err
} }
blobExists := false blobExists := false
blobPath := path.Join(s.BlobDir, blobHash) blobPath := path.Join(s.BlobDir, blobHash)
if _, err := os.Stat(blobPath); !os.IsNotExist(err) { if !isSdBlob { // we have to say sd blobs are missing because if we say we have it, they wont try to send any content blobs
blobExists = true if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
blobExists = true
}
} }
err = s.sendBlobResponse(conn, blobExists) err = s.sendBlobResponse(conn, blobExists, isSdBlob)
if err != nil { if err != nil {
return err return err
} }
@ -116,6 +118,7 @@ func (s *Server) receiveBlob(conn net.Conn) error {
receivedBlobHash := getBlobHash(blob) receivedBlobHash := getBlobHash(blob)
if blobHash != receivedBlobHash { if blobHash != receivedBlobHash {
return fmt.Errorf("Hash of received blob data does not match hash from send request") return fmt.Errorf("Hash of received blob data does not match hash from send request")
// this can also happen if the blob size is wrong, because the server will read the wrong number of bytes from the stream
} }
log.Println("Got blob " + blobHash[:8]) log.Println("Got blob " + blobHash[:8])
@ -124,7 +127,7 @@ func (s *Server) receiveBlob(conn net.Conn) error {
return err return err
} }
return s.sendTransferResponse(conn, true) return s.sendTransferResponse(conn, true, isSdBlob)
} }
func (s *Server) doHandshake(conn net.Conn) error { func (s *Server) doHandshake(conn net.Conn) error {
@ -133,8 +136,8 @@ func (s *Server) doHandshake(conn net.Conn) error {
err := dec.Decode(&handshake) err := dec.Decode(&handshake)
if err != nil { if err != nil {
return err return err
} else if handshake.Version != protocolVersion1 { } else if handshake.Version != protocolVersion1 && handshake.Version != protocolVersion2 {
return fmt.Errorf("This server only supports protocol version 1") return fmt.Errorf("Protocol version not supported")
} }
resp, err := json.Marshal(handshakeRequestResponse{Version: handshake.Version}) resp, err := json.Marshal(handshakeRequestResponse{Version: handshake.Version})
@ -150,36 +153,74 @@ func (s *Server) doHandshake(conn net.Conn) error {
return nil return nil
} }
func (s *Server) readBlobRequest(conn net.Conn) (int, string, error) { func (s *Server) readBlobRequest(conn net.Conn) (int, string, bool, error) {
var sendRequest sendBlobRequest var sendRequest sendBlobRequest
dec := json.NewDecoder(conn) dec := json.NewDecoder(conn)
err := dec.Decode(&sendRequest) err := dec.Decode(&sendRequest)
if err != nil { if err != nil {
return 0, "", err return 0, "", false, err
} else if sendRequest.BlobSize > BlobSize {
return 0, "", fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes")
} }
return sendRequest.BlobSize, sendRequest.BlobHash, nil
if sendRequest.SdBlobHash != "" && sendRequest.BlobHash != "" {
return 0, "", false, fmt.Errorf("Invalid request")
}
var blobHash string
var blobSize int
isSdBlob := sendRequest.SdBlobHash != ""
if isSdBlob {
blobSize = sendRequest.SdBlobSize
blobHash = sendRequest.SdBlobHash
if blobSize > BlobSize {
return 0, "", isSdBlob, fmt.Errorf("SD blob cannot be more than " + strconv.Itoa(BlobSize) + " bytes")
}
} else {
blobSize = sendRequest.BlobSize
blobHash = sendRequest.BlobHash
if blobSize != BlobSize {
return 0, "", isSdBlob, fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes")
}
}
return blobSize, blobHash, isSdBlob, nil
} }
func (s *Server) sendBlobResponse(conn net.Conn, blobExists bool) error { func (s *Server) sendBlobResponse(conn net.Conn, blobExists, isSdBlob bool) error {
sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !blobExists}) var response []byte
var err error
if isSdBlob {
response, err = json.Marshal(sendSdBlobResponse{SendSdBlob: !blobExists})
} else {
response, err = json.Marshal(sendBlobResponse{SendBlob: !blobExists})
}
if err != nil { if err != nil {
return err return err
} }
_, err = conn.Write(sendResponse)
_, err = conn.Write(response)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob bool) error { func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob, isSdBlob bool) error {
transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob}) var response []byte
var err error
if isSdBlob {
response, err = json.Marshal(sdBlobTransferResponse{ReceivedSdBlob: receivedBlob})
} else {
response, err = json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob})
}
if err != nil { if err != nil {
return err return err
} }
_, err = conn.Write(transferResponse)
_, err = conn.Write(response)
if err != nil { if err != nil {
return err return err
} }

View file

@ -12,7 +12,7 @@ const (
BlobSize = 2 * 1024 * 1024 BlobSize = 2 * 1024 * 1024
protocolVersion1 = 1 protocolVersion1 = 1
protocolVersion2 = 2 // not implemented protocolVersion2 = 2
) )
var ErrBlobExists = fmt.Errorf("Blob exists on server") var ErrBlobExists = fmt.Errorf("Blob exists on server")
@ -26,18 +26,29 @@ type handshakeRequestResponse struct {
} }
type sendBlobRequest struct { type sendBlobRequest struct {
BlobHash string `json:"blob_hash"` BlobHash string `json:"blob_hash,omitempty"`
BlobSize int `json:"blob_size"` BlobSize int `json:"blob_size,omitempty"`
SdBlobHash string `json:"sd_blob_hash,omitempty"`
SdBlobSize int `json:"sd_blob_size,omitempty"`
} }
type sendBlobResponse struct { type sendBlobResponse struct {
SendBlob bool `json:"send_blob"` SendBlob bool `json:"send_blob"`
} }
type sendSdBlobResponse struct {
SendSdBlob bool `json:"send_sd_blob"`
NeededBlobs []string `json:"needed_blobs,omitempty"`
}
type blobTransferResponse struct { type blobTransferResponse struct {
ReceivedBlob bool `json:"received_blob"` ReceivedBlob bool `json:"received_blob"`
} }
type sdBlobTransferResponse struct {
ReceivedSdBlob bool `json:"received_sd_blob"`
}
func getBlobHash(blob []byte) string { func getBlobHash(blob []byte) string {
hashBytes := sha512.Sum384(blob) hashBytes := sha512.Sum384(blob)
return hex.EncodeToString(hashBytes[:]) return hex.EncodeToString(hashBytes[:])