simplify server, check that hashes match

This commit is contained in:
Alex Grintsvayg 2017-08-10 18:41:39 -04:00
parent 3e99eff3ec
commit 782090188e

View file

@ -87,54 +87,44 @@ func (s *Server) doError(conn net.Conn, e error) error {
} }
func (s *Server) receiveBlob(conn net.Conn) error { func (s *Server) receiveBlob(conn net.Conn) error {
var sendRequest sendBlobRequest blobSize, blobHash, err := s.readBlobRequest(conn)
dec := json.NewDecoder(conn)
err := dec.Decode(&sendRequest)
if err != nil {
return err
} else if sendRequest.BlobSize > BlobSize {
return fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes")
}
// check if blob exists
haveBlob := false
sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !haveBlob})
if err != nil { if err != nil {
return err return err
} }
_, err = conn.Write(sendResponse) blobExists := false
blobPath := path.Join(s.BlobDir, blobHash)
if _, err := os.Stat(blobPath); !os.IsNotExist(err) {
blobExists = true
}
err = s.sendBlobResponse(conn, blobExists)
if err != nil { if err != nil {
return err return err
} }
blob := make([]byte, sendRequest.BlobSize) if blobExists {
return nil
}
blob := make([]byte, blobSize)
_, err = io.ReadFull(bufio.NewReader(conn), blob) _, err = io.ReadFull(bufio.NewReader(conn), blob)
if err != nil { if err != nil {
return err return err
} }
blobHash := getBlobHash(blob) receivedBlobHash := getBlobHash(blob)
if sendRequest.BlobHash != blobHash { if blobHash != receivedBlobHash {
return fmt.Errorf("Hash of received blob does not match hash from send request") return fmt.Errorf("Hash of received blob data does not match hash from send request")
} }
log.Println("Got blob " + blobHash[:8]) log.Println("Got blob " + blobHash[:8])
err = ioutil.WriteFile(path.Join(s.BlobDir, blobHash), blob, 0644) err = ioutil.WriteFile(blobPath, blob, 0644)
if err != nil { if err != nil {
return err return err
} }
transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: true}) return s.sendTransferResponse(conn, true)
if err != nil {
return err
}
_, err = conn.Write(transferResponse)
if err != nil {
return err
}
return nil
} }
func (s *Server) doHandshake(conn net.Conn) error { func (s *Server) doHandshake(conn net.Conn) error {
@ -160,6 +150,42 @@ func (s *Server) doHandshake(conn net.Conn) error {
return nil return nil
} }
func (s *Server) readBlobRequest(conn net.Conn) (int, string, error) {
var sendRequest sendBlobRequest
dec := json.NewDecoder(conn)
err := dec.Decode(&sendRequest)
if err != nil {
return 0, "", err
} else if sendRequest.BlobSize > BlobSize {
return 0, "", fmt.Errorf("Blob size cannot be greater than " + strconv.Itoa(BlobSize) + " bytes")
}
return sendRequest.BlobSize, sendRequest.BlobHash, nil
}
func (s *Server) sendBlobResponse(conn net.Conn, blobExists bool) error {
sendResponse, err := json.Marshal(sendBlobResponse{SendBlob: !blobExists})
if err != nil {
return err
}
_, err = conn.Write(sendResponse)
if err != nil {
return err
}
return nil
}
func (s *Server) sendTransferResponse(conn net.Conn, receivedBlob bool) error {
transferResponse, err := json.Marshal(blobTransferResponse{ReceivedBlob: receivedBlob})
if err != nil {
return err
}
_, err = conn.Write(transferResponse)
if err != nil {
return err
}
return nil
}
func (s *Server) ensureBlobDirExists() error { func (s *Server) ensureBlobDirExists() error {
if stat, err := os.Stat(s.BlobDir); err != nil { if stat, err := os.Stat(s.BlobDir); err != nil {
if os.IsNotExist(err) { if os.IsNotExist(err) {