Insert in tx #51

Closed
lyoshenka wants to merge 52 commits from insert_in_tx into master
3 changed files with 125 additions and 67 deletions
Showing only changes of commit 2a1557845d - Show all commits

View file

@ -97,7 +97,7 @@ func (s *SQL) AddBlob(hash string, length int, isStored bool) error {
return errors.Err("not connected")
}
_, err := s.insertBlob(hash, length, isStored)
_, err := s.insertBlob(s.conn, hash, length, isStored)
return err
}
@ -163,7 +163,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
//args = append(args, hash, true, stream.MaxBlobSize, dayAgo)
}
q = strings.TrimSuffix(q, ",")
_, err := s.exec(q)
_, err := s.exec(s.conn, q)
if err != nil {
return err
}
@ -171,7 +171,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
return nil
}
func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error) {
func (s *SQL) insertBlob(ex Executor, hash string, length int, isStored bool) (int64, error) {
if length <= 0 {
return 0, errors.Err("length must be positive")
}
@ -188,13 +188,13 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))"
}
blobID, err := s.exec(q, args...)
blobID, err := s.exec(ex, q, args...)
if err != nil {
return 0, err
}
if blobID == 0 {
err = s.conn.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
err = ex.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
if err != nil {
return 0, errors.Err(err)
}
@ -203,7 +203,7 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
}
if s.TrackAccess == TrackAccessBlobs {
err := s.touchBlobs([]uint64{uint64(blobID)})
err := s.touchBlobs(ex, []uint64{uint64(blobID)})
if err != nil {
return 0, errors.Err(err)
}
@ -227,7 +227,7 @@ func (s *SQL) insertStream(hash string, sdBlobID int64) (int64, error) {
q = "INSERT IGNORE INTO stream (hash, sd_blob_id) VALUES (" + qt.Qs(len(args)) + ")"
}
streamID, err := s.exec(q, args...)
streamID, err := s.exec(s.conn, q, args...)
if err != nil {
return 0, errors.Err(err)
}
@ -266,16 +266,16 @@ func (s *SQL) HasBlobs(hashes []string, touch bool) (map[string]bool, error) {
if touch {
if s.TrackAccess == TrackAccessBlobs {
s.touchBlobs(idsNeedingTouch)
_ = s.touchBlobs(s.conn, idsNeedingTouch)
} else if s.TrackAccess == TrackAccessStreams {
s.touchStreams(idsNeedingTouch)
_ = s.touchStreams(idsNeedingTouch)
}
}
return exists, err
}
func (s *SQL) touchBlobs(blobIDs []uint64) error {
func (s *SQL) touchBlobs(ex Executor, blobIDs []uint64) error {
if len(blobIDs) == 0 {
return nil
}
@ -288,7 +288,7 @@ func (s *SQL) touchBlobs(blobIDs []uint64) error {
}
startTime := time.Now()
_, err := s.exec(query, args...)
_, err := s.exec(ex, query, args...)
log.Debugf("touched %d blobs and took %s", len(blobIDs), time.Since(startTime))
return errors.Err(err)
}
@ -306,7 +306,7 @@ func (s *SQL) touchStreams(streamIDs []uint64) error {
}
startTime := time.Now()
_, err := s.exec(query, args...)
_, err := s.exec(s.conn, query, args...)
log.Debugf("touched %d streams and took %s", len(streamIDs), time.Since(startTime))
return errors.Err(err)
}
@ -406,16 +406,16 @@ WHERE b.is_stored = 1 and b.hash IN (` + qt.Qs(len(batch)) + `)`
// NOTE: If SoftDelete is enabled, streams will never be deleted
func (s *SQL) Delete(hash string) error {
if s.SoftDelete {
_, err := s.exec("UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
_, err := s.exec(s.conn, "UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
return errors.Err(err)
}
_, err := s.exec("DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
_, err := s.exec(s.conn, "DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
if err != nil {
return errors.Err(err)
}
_, err = s.exec("DELETE FROM blob_ WHERE hash = ?", hash)
_, err = s.exec(s.conn, "DELETE FROM blob_ WHERE hash = ?", hash)
return errors.Err(err)
}
@ -590,7 +590,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
return errors.Err("not connected")
}
sdBlobID, err := s.insertBlob(sdHash, sdBlobLength, true)
sdBlobID, err := s.insertBlob(s.conn, sdHash, sdBlobLength, true)
if err != nil {
return err
}
@ -600,6 +600,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
return err
}
return withTx(s.conn, func(tx Transactor) error {
// insert content blobs and connect them to stream
for _, contentBlob := range sdBlob.Blobs {
if contentBlob.BlobHash == "" {
@ -607,13 +608,13 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
continue
}
blobID, err := s.insertBlob(contentBlob.BlobHash, contentBlob.Length, false)
blobID, err := s.insertBlob(tx, contentBlob.BlobHash, contentBlob.Length, false)
if err != nil {
return err
}
args := []interface{}{streamID, blobID, contentBlob.BlobNum}
_, err = s.exec(
_, err = s.exec(tx,
"INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")",
args...,
)
@ -622,6 +623,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
}
}
return nil
})
}
// GetHashRange gets the smallest and biggest hashes in the db
@ -694,38 +696,37 @@ func (s *SQL) GetStoredHashesInRange(ctx context.Context, start, end bits.Bitmap
}
// txFunc is a function that can be wrapped in a transaction
type txFunc func(tx *sql.Tx) error
type txFunc func(tx Transactor) error
// withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one.
// if dbOrTx is an sql.DB, a new transaction is started
// withTx wraps a function in an sql transaction. the transaction is committed if there's
// no error, or rolled back if there is one. if dbOrTx is not a Transactor (e.g. if it's
// an *sql.DB), withTx attempts to start a new transaction to use.
func withTx(dbOrTx interface{}, f txFunc) (err error) {
var tx *sql.Tx
var tx Transactor
var ok bool
switch t := dbOrTx.(type) {
case *sql.Tx:
tx = t
case *sql.DB:
tx, err = t.Begin()
tx, ok = dbOrTx.(Transactor)
if !ok {
tx, err = Begin(dbOrTx)
if err != nil {
return err
}
}
defer func() {
if p := recover(); p != nil {
if rollBackError := tx.Rollback(); rollBackError != nil {
log.Error("failed to rollback tx on panic - ", rollBackError)
log.Error("failed to rollback tx on panic: ", rollBackError)
}
panic(p)
err = errors.Prefix("panic", p)
} else if err != nil {
if rollBackError := tx.Rollback(); rollBackError != nil {
log.Error("failed to rollback tx on panic - ", rollBackError)
log.Error("failed to rollback tx: ", rollBackError)
}
} else {
err = errors.Err(tx.Commit())
}
}()
default:
return errors.Err("db or tx required")
}
return f(tx)
}
@ -739,12 +740,12 @@ func closeRows(rows *sql.Rows) {
}
}
func (s *SQL) exec(query string, args ...interface{}) (int64, error) {
func (s *SQL) exec(ex Executor, query string, args ...interface{}) (int64, error) {
s.logQuery(query, args...)
attempt, maxAttempts := 0, 3
Retry:
attempt++
result, err := s.conn.Exec(query, args...)
result, err := ex.Exec(query, args...)
if isLockTimeoutError(err) {
if attempt <= maxAttempts {
//Error 1205: Lock wait timeout exceeded; try restarting transaction

45
db/interfaces.go Normal file
View file

@ -0,0 +1,45 @@
package db
import (
"database/sql"
"github.com/lbryio/lbry.go/v2/extras/errors"
)
// Executor can perform SQL queries.
type Executor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// Transactor can commit and rollback, on top of being able to execute queries.
type Transactor interface {
Commit() error
Rollback() error
Executor
}
// Begin begins a transaction
func Begin(db interface{}) (Transactor, error) {
type beginner interface {
Begin() (Transactor, error)
}
creator, ok := db.(beginner)
if ok {
return creator.Begin()
}
type sqlBeginner interface {
Begin() (*sql.Tx, error)
}
creator2, ok := db.(sqlBeginner)
if ok {
return creator2.Begin()
}
return nil, errors.Err("database does not support transactions")
}

View file

@ -1,27 +1,39 @@
package shared
import (
"os"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestBlobTrace_Serialize(t *testing.T) {
hostName, err := os.Hostname()
require.NoError(t, err)
stack := NewBlobTrace(10*time.Second, "test")
stack.Stack(20*time.Second, "test2")
stack.Stack(30*time.Second, "test3")
serialized, err := stack.Serialize()
assert.NoError(t, err)
t.Log(serialized)
expected := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}"
require.NoError(t, err)
expected := `{"stacks":[{"timing":10000000000,"origin_name":"test","host_name":"` +
hostName +
`"},{"timing":20000000000,"origin_name":"test2","host_name":"` +
hostName +
`"},{"timing":30000000000,"origin_name":"test3","host_name":"` +
hostName +
`"}]}`
assert.Equal(t, expected, serialized)
}
func TestBlobTrace_Deserialize(t *testing.T) {
serialized := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}"
serialized := `{"stacks":[{"timing":10000000000,"origin_name":"test"},{"timing":20000000000,"origin_name":"test2"},{"timing":30000000000,"origin_name":"test3"}]}`
stack, err := Deserialize(serialized)
assert.NoError(t, err)
require.NoError(t, err)
assert.Len(t, stack.Stacks, 3)
assert.Equal(t, stack.Stacks[0].Timing, 10*time.Second)
assert.Equal(t, stack.Stacks[1].Timing, 20*time.Second)