From 734553dc2acd91e028a4e07d7aa5507815b30b51 Mon Sep 17 00:00:00 2001
From: Victor Shyba <victor.shyba@gmail.com>
Date: Thu, 20 May 2021 15:35:43 -0300
Subject: [PATCH] insert stream blobs under a transaction

---
 db/db.go | 88 ++++++++++++++++++++++++++++++++++++++++++++------------
 1 file changed, 69 insertions(+), 19 deletions(-)

diff --git a/db/db.go b/db/db.go
index ec47f7f..7a8efe4 100644
--- a/db/db.go
+++ b/db/db.go
@@ -171,6 +171,68 @@ func (s *SQL) insertBlobs(hashes []string) error {
 	return nil
 }
 
+func (s *SQL) makeInsertStreamTransaction(sdBlob SdBlob, streamID int64) (txFunc, error) {
+	return func(tx *sql.Tx) error {
+		for _, contentBlob := range sdBlob.Blobs {
+			if contentBlob.BlobHash == "" {
+				// null terminator blob
+				continue
+			}
+
+			var (
+				q    string
+				args []interface{}
+			)
+			if s.TrackAccess == TrackAccessBlobs {
+				args = []interface{}{contentBlob.BlobHash, false, contentBlob.Length, time.Now()}
+				q = "INSERT INTO blob_ (hash, is_stored, length, last_accessed_at) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored)), last_accessed_at = VALUES(last_accessed_at)"
+			} else {
+				args = []interface{}{contentBlob.BlobHash, false, contentBlob.Length}
+				q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))"
+			}
+
+			result, err := tx.Exec(q, args...)
+			if err != nil {
+				return err
+			}
+
+			blobID, err := result.LastInsertId()
+			if err != nil {
+				return err
+			}
+			if blobID == 0 {
+				err = tx.QueryRow("SELECT id FROM blob_ WHERE hash = ?", contentBlob.BlobHash).Scan(&blobID)
+				if err != nil {
+					return errors.Err(err)
+				}
+				if blobID == 0 {
+					return errors.Err("blob ID is 0 even after INSERTing and SELECTing")
+				}
+
+				if s.TrackAccess == TrackAccessBlobs {
+					err := s.touchBlobs([]uint64{uint64(blobID)})
+					if err != nil {
+						return errors.Err(err)
+					}
+				}
+			}
+			if err != nil {
+				return err
+			}
+
+			streamArgs := []interface{}{streamID, blobID, contentBlob.BlobNum}
+			_, err = tx.Exec(
+				"INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")",
+				streamArgs...,
+			)
+			if err != nil {
+				return errors.Err(err)
+			}
+		}
+		return nil
+	}, nil
+}
+
 func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error) {
 	if length <= 0 {
 		return 0, errors.Err("length must be positive")
@@ -601,25 +663,13 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
 	}
 
 	// insert content blobs and connect them to stream
-	for _, contentBlob := range sdBlob.Blobs {
-		if contentBlob.BlobHash == "" {
-			// null terminator blob
-			continue
-		}
-
-		blobID, err := s.insertBlob(contentBlob.BlobHash, contentBlob.Length, false)
-		if err != nil {
-			return err
-		}
-
-		args := []interface{}{streamID, blobID, contentBlob.BlobNum}
-		_, err = s.exec(
-			"INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")",
-			args...,
-		)
-		if err != nil {
-			return errors.Err(err)
-		}
+	insertTX, err := s.makeInsertStreamTransaction(sdBlob, streamID)
+	if err != nil {
+		return err
+	}
+	err = s.withTx(insertTX)
+	if err != nil {
+		return err
 	}
 	return nil
 }