some work on v2 of proto

This commit is contained in:
Alex Grintsvayg 2017-08-15 16:02:18 -04:00
parent 782090188e
commit 8ee6b26feb
5 changed files with 148 additions and 36 deletions

View file

@ -9,7 +9,8 @@ import (
)
type Client struct {
conn net.Conn
conn net.Conn
connected bool
}
func (c *Client) Connect(address string) error {
@ -18,13 +19,19 @@ func (c *Client) Connect(address string) error {
if err != nil {
return err
}
c.connected = true
return c.doHandshake(protocolVersion1)
}
func (c *Client) Close() error {
c.connected = false
return c.conn.Close()
}
func (c *Client) SendBlob(blob []byte) error {
if !c.connected {
return fmt.Errorf("Not connected")
}
if len(blob) != BlobSize {
return fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes")
}
@ -37,7 +44,6 @@ func (c *Client) SendBlob(blob []byte) error {
if err != nil {
return err
}
_, err = c.conn.Write(sendRequest)
if err != nil {
return err
@ -75,6 +81,10 @@ func (c *Client) SendBlob(blob []byte) error {
}
func (c *Client) doHandshake(version int) error {
if !c.connected {
return fmt.Errorf("Not connected")
}
handshake, err := json.Marshal(handshakeRequestResponse{Version: version})
if err != nil {
return err

60
client_test.go Normal file
View file

@ -0,0 +1,60 @@
package main
import (
"io/ioutil"
"math/rand"
"os"
"strconv"
"testing"
"time"
)
var address = "localhost:" + strconv.Itoa(DefaultPort)
var s Server
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
dir, err := ioutil.TempDir("", "reflector_client_test")
if err != nil {
panic(err)
}
defer os.RemoveAll(dir)
s := NewServer(dir)
go s.ListenAndServe(address)
os.Exit(m.Run())
}
func TestNotConnected(t *testing.T) {
c := Client{}
err := c.SendBlob([]byte{})
if err == nil {
t.Error("client should error if it is not connected")
}
}
func TestSmallBlob(t *testing.T) {
c := Client{}
err := c.Connect(address)
if err != nil {
t.Error(err)
}
err = c.SendBlob([]byte{})
if err == nil {
t.Error("client should error if blob is empty")
}
blob := make([]byte, 1000)
_, err = rand.Read(blob)
if err != nil {
t.Error("failed to make random blob")
}
err = c.SendBlob([]byte{})
if err == nil {
t.Error("client should error if blob is the wrong size")
}
}

16
main.go
View file

@ -4,6 +4,7 @@ import (
"flag"
"log"
"math/rand"
"strconv"
"time"
)
@ -17,7 +18,8 @@ func main() {
var err error
rand.Seed(time.Now().UnixNano())
address := "localhost:5566"
port := DefaultPort
address := "52.14.109.125:" + strconv.Itoa(port)
serve := flag.Bool("server", false, "Run server")
blobDir := flag.String("blobdir", "", "Where blobs will be saved to")
@ -49,16 +51,4 @@ func main() {
checkErr(err)
err = client.SendBlob(blob)
checkErr(err)
blob = make([]byte, 2*1024*1024)
_, err = rand.Read(blob)
checkErr(err)
err = client.SendBlob(blob)
checkErr(err)
blob = make([]byte, 2*1024*1024)
_, err = rand.Read(blob)
checkErr(err)
err = client.SendBlob(blob)
checkErr(err)
}

View file

@ -87,18 +87,20 @@ func (s *Server) doError(conn net.Conn, e error) error {
}
func (s *Server) receiveBlob(conn net.Conn) error {
blobSize, blobHash, err := s.readBlobRequest(conn)
blobSize, blobHash, isSdBlob, err := s.readBlobRequest(conn)
if err != nil {
return err
}
blobExists := false
blobPath := path.Join(s.BlobDir, blobHash)
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
blobExists = true
if !isSdBlob { // we have to say sd blobs are missing because if we say we have it, they wont try to send any content blobs
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
blobExists = true
}
}
err = s.sendBlobResponse(conn, blobExists)
err = s.sendBlobResponse(conn, blobExists, isSdBlob)
if err != nil {
return err
}
@ -116,6 +118,7 @@ func (s *Server) receiveBlob(conn net.Conn) error {
receivedBlobHash := getBlobHash(blob)
if blobHash != receivedBlobHash {
return fmt.Errorf("Hash of received blob data does not match hash from send request")
// this can also happen if the blob size is wrong, because the server will read the wrong number of bytes from the stream
}
log.Println("Got blob " + blobHash[:8])
@ -124,7 +127,7 @@ func (s *Server) receiveBlob(conn net.Conn) error {
return err
}
return s.sendTransferResponse(conn, true)
return s.sendTransferResponse(conn, true, isSdBlob)
}
func (s *Server) doHandshake(conn net.Conn) error {
@ -133,8 +136,8 @@ func (s *Server) doHandshake(conn net.Conn) error {
err := dec.Decode(&handshake)
if err != nil {
return err
} else if handshake.Version != protocolVersion1 {
return fmt.Errorf("This server only supports protocol version 1")
} else if handshake.Version != protocolVersion1 && handshake.Version != protocolVersion2 {
return fmt.Errorf("Protocol version not supported")
}
resp, err := json.Marshal(handshakeRequestResponse{Version: handshake.Version})
@ -150,36 +153,74 @@ func (s *Server) doHandshake(conn net.Conn) error {
return nil
}
func (s *Server) readBlobRequest(conn net.Conn) (int, string, error) {
func (s *Server) readBlobRequest(conn net.Conn) (int, string, bool, error) {
var sendRequest sendBlobRequest
dec := json.NewDecoder(conn)
err := dec.Decode(&sendRequest)
if err != nil {
return 0, "", err
} else if sendRequest.BlobSize > BlobSize {
return 0, "", fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes")
return 0, "", false, err
}
return sendRequest.BlobSize, sendRequest.BlobHash, nil
if sendRequest.SdBlobHash != "" && sendRequest.BlobHash != "" {
return 0, "", false, fmt.Errorf("Invalid request")
}
var blobHash string
var blobSize int
isSdBlob := sendRequest.SdBlobHash != ""
if isSdBlob {
blobSize = sendRequest.SdBlobSize
blobHash = sendRequest.SdBlobHash
if blobSize > BlobSize {
return 0, "", isSdBlob, fmt.Errorf("SD blob cannot be more than " + strconv.Itoa(BlobSize) + " bytes")
}
} else {
blobSize = sendRequest.BlobSize
blobHash = sendRequest.BlobHash
if blobSize != BlobSize {
return 0, "", isSdBlob, fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes")
}
}
return blobSize, blobHash, isSdBlob, nil
}
func (s *Server) sendBlobResponse(conn net.Conn, blobExists bool) error {
sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !blobExists})
func (s *Server) sendBlobResponse(conn net.Conn, blobExists, isSdBlob bool) error {
var response []byte
var err error
if isSdBlob {
response, err = json.Marshal(sendSdBlobResponse{SendSdBlob: !blobExists})
} else {
response, err = json.Marshal(sendBlobResponse{SendBlob: !blobExists})
}
if err != nil {
return err
}
_, err = conn.Write(sendResponse)
_, err = conn.Write(response)
if err != nil {
return err
}
return nil
}
func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob bool) error {
transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob})
func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob, isSdBlob bool) error {
var response []byte
var err error
if isSdBlob {
response, err = json.Marshal(sdBlobTransferResponse{ReceivedSdBlob: receivedBlob})
} else {
response, err = json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob})
}
if err != nil {
return err
}
_, err = conn.Write(transferResponse)
_, err = conn.Write(response)
if err != nil {
return err
}

View file

@ -12,7 +12,7 @@ const (
BlobSize = 2 * 1024 * 1024
protocolVersion1 = 1
protocolVersion2 = 2 // not implemented
protocolVersion2 = 2
)
var ErrBlobExists = fmt.Errorf("Blob exists on server")
@ -26,18 +26,29 @@ type handshakeRequestResponse struct {
}
type sendBlobRequest struct {
BlobHash string `json:"blob_hash"`
BlobSize int `json:"blob_size"`
BlobHash string `json:"blob_hash,omitempty"`
BlobSize int `json:"blob_size,omitempty"`
SdBlobHash string `json:"sd_blob_hash,omitempty"`
SdBlobSize int `json:"sd_blob_size,omitempty"`
}
type sendBlobResponse struct {
SendBlob bool `json:"send_blob"`
}
type sendSdBlobResponse struct {
SendSdBlob bool `json:"send_sd_blob"`
NeededBlobs []string `json:"needed_blobs,omitempty"`
}
type blobTransferResponse struct {
ReceivedBlob bool `json:"received_blob"`
}
type sdBlobTransferResponse struct {
ReceivedSdBlob bool `json:"received_sd_blob"`
}
func getBlobHash(blob []byte) string {
hashBytes := sha512.Sum384(blob)
return hex.EncodeToString(hashBytes[:])