diff --git a/db/db.go b/db/db.go index ec47f7f..7a8efe4 100644 --- a/db/db.go +++ b/db/db.go @@ -171,6 +171,68 @@ func (s *SQL) insertBlobs(hashes []string) error { return nil } +func (s *SQL) makeInsertStreamTransaction(sdBlob SdBlob, streamID int64) (txFunc, error) { + return func(tx *sql.Tx) error { + for _, contentBlob := range sdBlob.Blobs { + if contentBlob.BlobHash == "" { + // null terminator blob + continue + } + + var ( + q string + args []interface{} + ) + if s.TrackAccess == TrackAccessBlobs { + args = []interface{}{contentBlob.BlobHash, false, contentBlob.Length, time.Now()} + q = "INSERT INTO blob_ (hash, is_stored, length, last_accessed_at) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored)), last_accessed_at = VALUES(last_accessed_at)" + } else { + args = []interface{}{contentBlob.BlobHash, false, contentBlob.Length} + q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))" + } + + result, err := tx.Exec(q, args...) + if err != nil { + return err + } + + blobID, err := result.LastInsertId() + if err != nil { + return err + } + if blobID == 0 { + err = tx.QueryRow("SELECT id FROM blob_ WHERE hash = ?", contentBlob.BlobHash).Scan(&blobID) + if err != nil { + return errors.Err(err) + } + if blobID == 0 { + return errors.Err("blob ID is 0 even after INSERTing and SELECTing") + } + + if s.TrackAccess == TrackAccessBlobs { + err := s.touchBlobs([]uint64{uint64(blobID)}) + if err != nil { + return errors.Err(err) + } + } + } + if err != nil { + return err + } + + streamArgs := []interface{}{streamID, blobID, contentBlob.BlobNum} + _, err = tx.Exec( + "INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")", + streamArgs..., + ) + if err != nil { + return errors.Err(err) + } + } + return nil + }, nil +} + func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error) { if length <= 0 { return 0, errors.Err("length must be positive") @@ -601,25 +663,13 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error { } // insert content blobs and connect them to stream - for _, contentBlob := range sdBlob.Blobs { - if contentBlob.BlobHash == "" { - // null terminator blob - continue - } - - blobID, err := s.insertBlob(contentBlob.BlobHash, contentBlob.Length, false) - if err != nil { - return err - } - - args := []interface{}{streamID, blobID, contentBlob.BlobNum} - _, err = s.exec( - "INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")", - args..., - ) - if err != nil { - return errors.Err(err) - } + insertTX, err := s.makeInsertStreamTransaction(sdBlob, streamID) + if err != nil { + return err + } + err = s.withTx(insertTX) + if err != nil { + return err } return nil }