Insert in tx #51

Closed
lyoshenka wants to merge 52 commits from insert_in_tx into master
3 changed files with 125 additions and 67 deletions
Showing only changes of commit 2a1557845d - Show all commits

View file

@ -97,7 +97,7 @@ func (s *SQL) AddBlob(hash string, length int, isStored bool) error {
return errors.Err("not connected") return errors.Err("not connected")
} }
_, err := s.insertBlob(hash, length, isStored) _, err := s.insertBlob(s.conn, hash, length, isStored)
return err return err
} }
@ -163,7 +163,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
//args = append(args, hash, true, stream.MaxBlobSize, dayAgo) //args = append(args, hash, true, stream.MaxBlobSize, dayAgo)
} }
q = strings.TrimSuffix(q, ",") q = strings.TrimSuffix(q, ",")
_, err := s.exec(q) _, err := s.exec(s.conn, q)
if err != nil { if err != nil {
return err return err
} }
@ -171,7 +171,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
return nil return nil
} }
func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error) { func (s *SQL) insertBlob(ex Executor, hash string, length int, isStored bool) (int64, error) {
if length <= 0 { if length <= 0 {
return 0, errors.Err("length must be positive") return 0, errors.Err("length must be positive")
} }
@ -188,13 +188,13 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))" q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))"
} }
blobID, err := s.exec(q, args...) blobID, err := s.exec(ex, q, args...)
if err != nil { if err != nil {
return 0, err return 0, err
} }
if blobID == 0 { if blobID == 0 {
err = s.conn.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID) err = ex.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
if err != nil { if err != nil {
return 0, errors.Err(err) return 0, errors.Err(err)
} }
@ -203,7 +203,7 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
} }
if s.TrackAccess == TrackAccessBlobs { if s.TrackAccess == TrackAccessBlobs {
err := s.touchBlobs([]uint64{uint64(blobID)}) err := s.touchBlobs(ex, []uint64{uint64(blobID)})
if err != nil { if err != nil {
return 0, errors.Err(err) return 0, errors.Err(err)
} }
@ -227,7 +227,7 @@ func (s *SQL) insertStream(hash string, sdBlobID int64) (int64, error) {
q = "INSERT IGNORE INTO stream (hash, sd_blob_id) VALUES (" + qt.Qs(len(args)) + ")" q = "INSERT IGNORE INTO stream (hash, sd_blob_id) VALUES (" + qt.Qs(len(args)) + ")"
} }
streamID, err := s.exec(q, args...) streamID, err := s.exec(s.conn, q, args...)
if err != nil { if err != nil {
return 0, errors.Err(err) return 0, errors.Err(err)
} }
@ -266,16 +266,16 @@ func (s *SQL) HasBlobs(hashes []string, touch bool) (map[string]bool, error) {
if touch { if touch {
if s.TrackAccess == TrackAccessBlobs { if s.TrackAccess == TrackAccessBlobs {
s.touchBlobs(idsNeedingTouch) _ = s.touchBlobs(s.conn, idsNeedingTouch)
} else if s.TrackAccess == TrackAccessStreams { } else if s.TrackAccess == TrackAccessStreams {
s.touchStreams(idsNeedingTouch) _ = s.touchStreams(idsNeedingTouch)
} }
} }
return exists, err return exists, err
} }
func (s *SQL) touchBlobs(blobIDs []uint64) error { func (s *SQL) touchBlobs(ex Executor, blobIDs []uint64) error {
if len(blobIDs) == 0 { if len(blobIDs) == 0 {
return nil return nil
} }
@ -288,7 +288,7 @@ func (s *SQL) touchBlobs(blobIDs []uint64) error {
} }
startTime := time.Now() startTime := time.Now()
_, err := s.exec(query, args...) _, err := s.exec(ex, query, args...)
log.Debugf("touched %d blobs and took %s", len(blobIDs), time.Since(startTime)) log.Debugf("touched %d blobs and took %s", len(blobIDs), time.Since(startTime))
return errors.Err(err) return errors.Err(err)
} }
@ -306,7 +306,7 @@ func (s *SQL) touchStreams(streamIDs []uint64) error {
} }
startTime := time.Now() startTime := time.Now()
_, err := s.exec(query, args...) _, err := s.exec(s.conn, query, args...)
log.Debugf("touched %d streams and took %s", len(streamIDs), time.Since(startTime)) log.Debugf("touched %d streams and took %s", len(streamIDs), time.Since(startTime))
return errors.Err(err) return errors.Err(err)
} }
@ -406,16 +406,16 @@ WHERE b.is_stored = 1 and b.hash IN (` + qt.Qs(len(batch)) + `)`
// NOTE: If SoftDelete is enabled, streams will never be deleted // NOTE: If SoftDelete is enabled, streams will never be deleted
func (s *SQL) Delete(hash string) error { func (s *SQL) Delete(hash string) error {
if s.SoftDelete { if s.SoftDelete {
_, err := s.exec("UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash) _, err := s.exec(s.conn, "UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
return errors.Err(err) return errors.Err(err)
} }
_, err := s.exec("DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash) _, err := s.exec(s.conn, "DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
if err != nil { if err != nil {
return errors.Err(err) return errors.Err(err)
} }
_, err = s.exec("DELETE FROM blob_ WHERE hash = ?", hash) _, err = s.exec(s.conn, "DELETE FROM blob_ WHERE hash = ?", hash)
return errors.Err(err) return errors.Err(err)
} }
@ -590,7 +590,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
return errors.Err("not connected") return errors.Err("not connected")
} }
sdBlobID, err := s.insertBlob(sdHash, sdBlobLength, true) sdBlobID, err := s.insertBlob(s.conn, sdHash, sdBlobLength, true)
if err != nil { if err != nil {
return err return err
} }
@ -600,6 +600,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
return err return err
} }
return withTx(s.conn, func(tx Transactor) error {
// insert content blobs and connect them to stream // insert content blobs and connect them to stream
for _, contentBlob := range sdBlob.Blobs { for _, contentBlob := range sdBlob.Blobs {
if contentBlob.BlobHash == "" { if contentBlob.BlobHash == "" {
@ -607,13 +608,13 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
continue continue
} }
blobID, err := s.insertBlob(contentBlob.BlobHash, contentBlob.Length, false) blobID, err := s.insertBlob(tx, contentBlob.BlobHash, contentBlob.Length, false)
if err != nil { if err != nil {
return err return err
} }
args := []interface{}{streamID, blobID, contentBlob.BlobNum} args := []interface{}{streamID, blobID, contentBlob.BlobNum}
_, err = s.exec( _, err = s.exec(tx,
"INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")", "INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")",
args..., args...,
) )
@ -622,6 +623,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
} }
} }
return nil return nil
})
} }
// GetHashRange gets the smallest and biggest hashes in the db // GetHashRange gets the smallest and biggest hashes in the db
@ -694,38 +696,37 @@ func (s *SQL) GetStoredHashesInRange(ctx context.Context, start, end bits.Bitmap
} }
// txFunc is a function that can be wrapped in a transaction // txFunc is a function that can be wrapped in a transaction
type txFunc func(tx *sql.Tx) error type txFunc func(tx Transactor) error
// withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one. // withTx wraps a function in an sql transaction. the transaction is committed if there's
// if dbOrTx is an sql.DB, a new transaction is started // no error, or rolled back if there is one. if dbOrTx is not a Transactor (e.g. if it's
// an *sql.DB), withTx attempts to start a new transaction to use.
func withTx(dbOrTx interface{}, f txFunc) (err error) { func withTx(dbOrTx interface{}, f txFunc) (err error) {
var tx *sql.Tx var tx Transactor
var ok bool
switch t := dbOrTx.(type) { tx, ok = dbOrTx.(Transactor)
case *sql.Tx: if !ok {
tx = t tx, err = Begin(dbOrTx)
case *sql.DB:
tx, err = t.Begin()
if err != nil { if err != nil {
return err return err
} }
}
defer func() { defer func() {
if p := recover(); p != nil { if p := recover(); p != nil {
if rollBackError := tx.Rollback(); rollBackError != nil { if rollBackError := tx.Rollback(); rollBackError != nil {
log.Error("failed to rollback tx on panic - ", rollBackError) log.Error("failed to rollback tx on panic: ", rollBackError)
} }
panic(p) err = errors.Prefix("panic", p)
} else if err != nil { } else if err != nil {
if rollBackError := tx.Rollback(); rollBackError != nil { if rollBackError := tx.Rollback(); rollBackError != nil {
log.Error("failed to rollback tx on panic - ", rollBackError) log.Error("failed to rollback tx: ", rollBackError)
} }
} else { } else {
err = errors.Err(tx.Commit()) err = errors.Err(tx.Commit())
} }
}() }()
default:
return errors.Err("db or tx required")
}
return f(tx) return f(tx)
} }
@ -739,12 +740,12 @@ func closeRows(rows *sql.Rows) {
} }
} }
func (s *SQL) exec(query string, args ...interface{}) (int64, error) { func (s *SQL) exec(ex Executor, query string, args ...interface{}) (int64, error) {
s.logQuery(query, args...) s.logQuery(query, args...)
attempt, maxAttempts := 0, 3 attempt, maxAttempts := 0, 3
Retry: Retry:
attempt++ attempt++
result, err := s.conn.Exec(query, args...) result, err := ex.Exec(query, args...)
if isLockTimeoutError(err) { if isLockTimeoutError(err) {
if attempt <= maxAttempts { if attempt <= maxAttempts {
//Error 1205: Lock wait timeout exceeded; try restarting transaction //Error 1205: Lock wait timeout exceeded; try restarting transaction

45
db/interfaces.go Normal file
View file

@ -0,0 +1,45 @@
package db
import (
"database/sql"
"github.com/lbryio/lbry.go/v2/extras/errors"
)
// Executor can perform SQL queries.
type Executor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// Transactor can commit and rollback, on top of being able to execute queries.
type Transactor interface {
Commit() error
Rollback() error
Executor
}
// Begin begins a transaction
func Begin(db interface{}) (Transactor, error) {
type beginner interface {
Begin() (Transactor, error)
}
creator, ok := db.(beginner)
if ok {
return creator.Begin()
}
type sqlBeginner interface {
Begin() (*sql.Tx, error)
}
creator2, ok := db.(sqlBeginner)
if ok {
return creator2.Begin()
}
return nil, errors.Err("database does not support transactions")
}

View file

@ -1,27 +1,39 @@
package shared package shared
import ( import (
"os"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestBlobTrace_Serialize(t *testing.T) { func TestBlobTrace_Serialize(t *testing.T) {
hostName, err := os.Hostname()
require.NoError(t, err)
stack := NewBlobTrace(10*time.Second, "test") stack := NewBlobTrace(10*time.Second, "test")
stack.Stack(20*time.Second, "test2") stack.Stack(20*time.Second, "test2")
stack.Stack(30*time.Second, "test3") stack.Stack(30*time.Second, "test3")
serialized, err := stack.Serialize() serialized, err := stack.Serialize()
assert.NoError(t, err) require.NoError(t, err)
t.Log(serialized)
expected := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}" expected := `{"stacks":[{"timing":10000000000,"origin_name":"test","host_name":"` +
hostName +
`"},{"timing":20000000000,"origin_name":"test2","host_name":"` +
hostName +
`"},{"timing":30000000000,"origin_name":"test3","host_name":"` +
hostName +
`"}]}`
assert.Equal(t, expected, serialized) assert.Equal(t, expected, serialized)
} }
func TestBlobTrace_Deserialize(t *testing.T) { func TestBlobTrace_Deserialize(t *testing.T) {
serialized := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}" serialized := `{"stacks":[{"timing":10000000000,"origin_name":"test"},{"timing":20000000000,"origin_name":"test2"},{"timing":30000000000,"origin_name":"test3"}]}`
stack, err := Deserialize(serialized) stack, err := Deserialize(serialized)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, stack.Stacks, 3) assert.Len(t, stack.Stacks, 3)
assert.Equal(t, stack.Stacks[0].Timing, 10*time.Second) assert.Equal(t, stack.Stacks[0].Timing, 10*time.Second)
assert.Equal(t, stack.Stacks[1].Timing, 20*time.Second) assert.Equal(t, stack.Stacks[1].Timing, 20*time.Second)