diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..65444fc --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/.idea +/blobs diff --git a/client.go b/client.go new file mode 100644 index 0000000..f269770 --- /dev/null +++ b/client.go @@ -0,0 +1,98 @@ +package main + +import ( + "encoding/json" + "fmt" + "log" + "net" + "strconv" +) + +type Client struct { + conn net.Conn +} + +func (c *Client) Connect(address string) error { + var err error + c.conn, err = net.Dial("tcp", address) + if err != nil { + return err + } + return c.doHandshake(protocolVersion1) +} +func (c *Client) Close() error { + return c.conn.Close() +} + +func (c *Client) SendBlob(blob []byte) error { + if len(blob) != BlobSize { + return fmt.Errorf("Blob must be exactly " + strconv.Itoa(BlobSize) + " bytes") + } + + blobHash := getBlobHash(blob) + sendRequest, err := json.Marshal(sendBlobRequest{ + BlobSize: len(blob), + BlobHash: blobHash, + }) + if err != nil { + return err + } + + _, err = c.conn.Write(sendRequest) + if err != nil { + return err + } + + dec := json.NewDecoder(c.conn) + + var sendResp sendBlobResponse + err = dec.Decode(&sendResp) + if err != nil { + return err + } + + if !sendResp.SendBlob { + return ErrBlobExists + } + + log.Println("Sending blob " + blobHash[:8]) + + _, err = c.conn.Write(blob) + if err != nil { + return err + } + var transferResp blobTransferResponse + err = dec.Decode(&transferResp) + if err != nil { + return err + } + + if !transferResp.ReceivedBlob { + return fmt.Errorf("Server did not received blob") + } + + return nil +} + +func (c *Client) doHandshake(version int) error { + handshake, err := json.Marshal(handshakeRequestResponse{Version: version}) + if err != nil { + return err + } + + _, err = c.conn.Write(handshake) + if err != nil { + return err + } + + var resp handshakeRequestResponse + dec := json.NewDecoder(c.conn) + err = dec.Decode(&resp) + if err != nil { + return err + } else if resp.Version != version { + return fmt.Errorf("Handshake version mismatch") + } + + return nil +} diff --git a/main.go b/main.go new file mode 100644 index 0000000..deb168c --- /dev/null +++ b/main.go @@ -0,0 +1,64 @@ +package main + +import ( + "flag" + "log" + "math/rand" + "time" +) + +func checkErr(err error) { + if err != nil { + panic(err) + } +} + +func main() { + var err error + rand.Seed(time.Now().UnixNano()) + + address := "localhost:5566" + + serve := flag.Bool("server", false, "Run server") + blobDir := flag.String("blobdir", "", "Where blobs will be saved to") + flag.Parse() + if serve != nil && *serve { + if blobDir == nil || *blobDir == "" { + log.Fatal("-blobdir required") + } + server := NewServer(*blobDir) + log.Fatal(server.ListenAndServe(address)) + return + } + + client := Client{} + + log.Println("Connecting to " + address) + err = client.Connect(address) + checkErr(err) + + log.Println("Connected") + + defer func() { + log.Println("Closing connection") + client.Close() + }() + + blob := make([]byte, 2*1024*1024) + _, err = rand.Read(blob) + checkErr(err) + err = client.SendBlob(blob) + checkErr(err) + + blob = make([]byte, 2*1024*1024) + _, err = rand.Read(blob) + checkErr(err) + err = client.SendBlob(blob) + checkErr(err) + + blob = make([]byte, 2*1024*1024) + _, err = rand.Read(blob) + checkErr(err) + err = client.SendBlob(blob) + checkErr(err) +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..9b5554d --- /dev/null +++ b/server.go @@ -0,0 +1,177 @@ +package main + +import ( + "bufio" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "log" + "net" + "os" + "path" + "strconv" +) + +type Server struct { + BlobDir string +} + +func NewServer(blobDir string) *Server { + return &Server{ + BlobDir: blobDir, + } +} + +func (s *Server) ListenAndServe(address string) error { + log.Println("Blobs will be saved to " + s.BlobDir) + err := s.ensureBlobDirExists() + if err != nil { + return err + } + + log.Println("Listening on " + address) + l, err := net.Listen("tcp", address) + if err != nil { + return err + } + defer l.Close() + + for { + conn, err := l.Accept() + if err != nil { + // TODO: dont crash server on error here + return err + } + go s.handleConn(conn) + } +} + +func (s *Server) handleConn(conn net.Conn) { + // TODO: connection should time out eventually + defer conn.Close() + + err := s.doHandshake(conn) + if err != nil { + if err == io.EOF { + return + } + s.doError(conn, err) + return + } + + for { + err = s.receiveBlob(conn) + if err != nil { + if err == io.EOF { + return + } + s.doError(conn, err) + return + } + + } +} + +func (s *Server) doError(conn net.Conn, e error) error { + log.Println("Error: " + e.Error()) + if e2, ok := e.(*json.SyntaxError); ok { + log.Printf("syntax error at byte offset %d", e2.Offset) + } + resp, err := json.Marshal(errorResponse{Error: e.Error()}) + if err != nil { + return err + } + _, err = conn.Write(resp) + return err +} + +func (s *Server) receiveBlob(conn net.Conn) error { + var sendRequest sendBlobRequest + dec := json.NewDecoder(conn) + err := dec.Decode(&sendRequest) + if err != nil { + return err + } else if sendRequest.BlobSize > BlobSize { + return fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes") + } + + // check if blob exists + haveBlob := false + sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !haveBlob}) + if err != nil { + return err + } + + _, err = conn.Write(sendResponse) + if err != nil { + return err + } + + blob := make([]byte, sendRequest.BlobSize) + _, err = io.ReadFull(bufio.NewReader(conn), blob) + if err != nil { + return err + } + + blobHash := getBlobHash(blob) + if sendRequest.BlobHash != blobHash { + return fmt.Errorf("Hash of received blob does not match hash from send request") + } + log.Println("Got blob " + blobHash[:8]) + + err = ioutil.WriteFile(path.Join(s.BlobDir, blobHash), blob, 0644) + if err != nil { + return err + } + + transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: true}) + if err != nil { + return err + } + + _, err = conn.Write(transferResponse) + if err != nil { + return err + } + return nil +} + +func (s *Server) doHandshake(conn net.Conn) error { + var handshake handshakeRequestResponse + dec := json.NewDecoder(conn) + err := dec.Decode(&handshake) + if err != nil { + return err + } else if handshake.Version != protocolVersion1 { + return fmt.Errorf("This server only supports protocol version 1") + } + + resp, err := json.Marshal(handshakeRequestResponse{Version: handshake.Version}) + if err != nil { + return err + } + + _, err = conn.Write(resp) + if err != nil { + return err + } + + return nil +} + +func (s *Server) ensureBlobDirExists() error { + if stat, err := os.Stat(s.BlobDir); err != nil { + if os.IsNotExist(err) { + err2 := os.Mkdir(s.BlobDir, 0755) + if err2 != nil { + return err2 + } + } else { + return err + } + } else if !stat.IsDir() { + return fmt.Errorf("blob dir exists but is not a dir") + } + return nil +} diff --git a/shared.go b/shared.go new file mode 100644 index 0000000..689eeab --- /dev/null +++ b/shared.go @@ -0,0 +1,44 @@ +package main + +import ( + "crypto/sha512" + "encoding/hex" + "fmt" +) + +const ( + DefaultPort = 5566 + + BlobSize = 2 * 1024 * 1024 + + protocolVersion1 = 1 + protocolVersion2 = 2 // not implemented +) + +var ErrBlobExists = fmt.Errorf("Blob exists on server") + +type errorResponse struct { + Error string `json:"error"` +} + +type handshakeRequestResponse struct { + Version int `json:"version"` +} + +type sendBlobRequest struct { + BlobHash string `json:"blob_hash"` + BlobSize int `json:"blob_size"` +} + +type sendBlobResponse struct { + SendBlob bool `json:"send_blob"` +} + +type blobTransferResponse struct { + ReceivedBlob bool `json:"received_blob"` +} + +func getBlobHash(blob []byte) string { + hashBytes := sha512.Sum384(blob) + return hex.EncodeToString(hashBytes[:]) +}