Insert in tx #51
3 changed files with 125 additions and 67 deletions
71
db/db.go
71
db/db.go
|
@ -97,7 +97,7 @@ func (s *SQL) AddBlob(hash string, length int, isStored bool) error {
|
|||
return errors.Err("not connected")
|
||||
}
|
||||
|
||||
_, err := s.insertBlob(hash, length, isStored)
|
||||
_, err := s.insertBlob(s.conn, hash, length, isStored)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -163,7 +163,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
|
|||
//args = append(args, hash, true, stream.MaxBlobSize, dayAgo)
|
||||
}
|
||||
q = strings.TrimSuffix(q, ",")
|
||||
_, err := s.exec(q)
|
||||
_, err := s.exec(s.conn, q)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -171,7 +171,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error) {
|
||||
func (s *SQL) insertBlob(ex Executor, hash string, length int, isStored bool) (int64, error) {
|
||||
if length <= 0 {
|
||||
return 0, errors.Err("length must be positive")
|
||||
}
|
||||
|
@ -188,13 +188,13 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
|
|||
q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))"
|
||||
}
|
||||
|
||||
blobID, err := s.exec(q, args...)
|
||||
blobID, err := s.exec(ex, q, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if blobID == 0 {
|
||||
err = s.conn.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
|
||||
err = ex.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
|
||||
if err != nil {
|
||||
return 0, errors.Err(err)
|
||||
}
|
||||
|
@ -203,7 +203,7 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
|
|||
}
|
||||
|
||||
if s.TrackAccess == TrackAccessBlobs {
|
||||
err := s.touchBlobs([]uint64{uint64(blobID)})
|
||||
err := s.touchBlobs(ex, []uint64{uint64(blobID)})
|
||||
if err != nil {
|
||||
return 0, errors.Err(err)
|
||||
}
|
||||
|
@ -227,7 +227,7 @@ func (s *SQL) insertStream(hash string, sdBlobID int64) (int64, error) {
|
|||
q = "INSERT IGNORE INTO stream (hash, sd_blob_id) VALUES (" + qt.Qs(len(args)) + ")"
|
||||
}
|
||||
|
||||
streamID, err := s.exec(q, args...)
|
||||
streamID, err := s.exec(s.conn, q, args...)
|
||||
if err != nil {
|
||||
return 0, errors.Err(err)
|
||||
}
|
||||
|
@ -266,16 +266,16 @@ func (s *SQL) HasBlobs(hashes []string, touch bool) (map[string]bool, error) {
|
|||
|
||||
if touch {
|
||||
if s.TrackAccess == TrackAccessBlobs {
|
||||
s.touchBlobs(idsNeedingTouch)
|
||||
_ = s.touchBlobs(s.conn, idsNeedingTouch)
|
||||
} else if s.TrackAccess == TrackAccessStreams {
|
||||
s.touchStreams(idsNeedingTouch)
|
||||
_ = s.touchStreams(idsNeedingTouch)
|
||||
}
|
||||
}
|
||||
|
||||
return exists, err
|
||||
}
|
||||
|
||||
func (s *SQL) touchBlobs(blobIDs []uint64) error {
|
||||
func (s *SQL) touchBlobs(ex Executor, blobIDs []uint64) error {
|
||||
if len(blobIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -288,7 +288,7 @@ func (s *SQL) touchBlobs(blobIDs []uint64) error {
|
|||
}
|
||||
|
||||
startTime := time.Now()
|
||||
_, err := s.exec(query, args...)
|
||||
_, err := s.exec(ex, query, args...)
|
||||
log.Debugf("touched %d blobs and took %s", len(blobIDs), time.Since(startTime))
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
@ -306,7 +306,7 @@ func (s *SQL) touchStreams(streamIDs []uint64) error {
|
|||
}
|
||||
|
||||
startTime := time.Now()
|
||||
_, err := s.exec(query, args...)
|
||||
_, err := s.exec(s.conn, query, args...)
|
||||
log.Debugf("touched %d streams and took %s", len(streamIDs), time.Since(startTime))
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
@ -406,16 +406,16 @@ WHERE b.is_stored = 1 and b.hash IN (` + qt.Qs(len(batch)) + `)`
|
|||
// NOTE: If SoftDelete is enabled, streams will never be deleted
|
||||
func (s *SQL) Delete(hash string) error {
|
||||
if s.SoftDelete {
|
||||
_, err := s.exec("UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
|
||||
_, err := s.exec(s.conn, "UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
_, err := s.exec("DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
|
||||
_, err := s.exec(s.conn, "DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
|
||||
if err != nil {
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
_, err = s.exec("DELETE FROM blob_ WHERE hash = ?", hash)
|
||||
_, err = s.exec(s.conn, "DELETE FROM blob_ WHERE hash = ?", hash)
|
||||
return errors.Err(err)
|
||||
}
|
||||
|
||||
|
@ -590,7 +590,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
|
|||
return errors.Err("not connected")
|
||||
}
|
||||
|
||||
sdBlobID, err := s.insertBlob(sdHash, sdBlobLength, true)
|
||||
sdBlobID, err := s.insertBlob(s.conn, sdHash, sdBlobLength, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -600,6 +600,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
|
|||
return err
|
||||
}
|
||||
|
||||
return withTx(s.conn, func(tx Transactor) error {
|
||||
// insert content blobs and connect them to stream
|
||||
for _, contentBlob := range sdBlob.Blobs {
|
||||
if contentBlob.BlobHash == "" {
|
||||
|
@ -607,13 +608,13 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
|
|||
continue
|
||||
}
|
||||
|
||||
blobID, err := s.insertBlob(contentBlob.BlobHash, contentBlob.Length, false)
|
||||
blobID, err := s.insertBlob(tx, contentBlob.BlobHash, contentBlob.Length, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
args := []interface{}{streamID, blobID, contentBlob.BlobNum}
|
||||
_, err = s.exec(
|
||||
_, err = s.exec(tx,
|
||||
"INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")",
|
||||
args...,
|
||||
)
|
||||
|
@ -622,6 +623,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
|
|||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetHashRange gets the smallest and biggest hashes in the db
|
||||
|
@ -694,38 +696,37 @@ func (s *SQL) GetStoredHashesInRange(ctx context.Context, start, end bits.Bitmap
|
|||
}
|
||||
|
||||
// txFunc is a function that can be wrapped in a transaction
|
||||
type txFunc func(tx *sql.Tx) error
|
||||
type txFunc func(tx Transactor) error
|
||||
|
||||
// withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one.
|
||||
// if dbOrTx is an sql.DB, a new transaction is started
|
||||
// withTx wraps a function in an sql transaction. the transaction is committed if there's
|
||||
// no error, or rolled back if there is one. if dbOrTx is not a Transactor (e.g. if it's
|
||||
// an *sql.DB), withTx attempts to start a new transaction to use.
|
||||
func withTx(dbOrTx interface{}, f txFunc) (err error) {
|
||||
var tx *sql.Tx
|
||||
var tx Transactor
|
||||
var ok bool
|
||||
|
||||
switch t := dbOrTx.(type) {
|
||||
case *sql.Tx:
|
||||
tx = t
|
||||
case *sql.DB:
|
||||
tx, err = t.Begin()
|
||||
tx, ok = dbOrTx.(Transactor)
|
||||
if !ok {
|
||||
tx, err = Begin(dbOrTx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
if rollBackError := tx.Rollback(); rollBackError != nil {
|
||||
log.Error("failed to rollback tx on panic - ", rollBackError)
|
||||
log.Error("failed to rollback tx on panic: ", rollBackError)
|
||||
}
|
||||
panic(p)
|
||||
err = errors.Prefix("panic", p)
|
||||
} else if err != nil {
|
||||
if rollBackError := tx.Rollback(); rollBackError != nil {
|
||||
log.Error("failed to rollback tx on panic - ", rollBackError)
|
||||
log.Error("failed to rollback tx: ", rollBackError)
|
||||
}
|
||||
} else {
|
||||
err = errors.Err(tx.Commit())
|
||||
}
|
||||
}()
|
||||
default:
|
||||
return errors.Err("db or tx required")
|
||||
}
|
||||
|
||||
return f(tx)
|
||||
}
|
||||
|
@ -739,12 +740,12 @@ func closeRows(rows *sql.Rows) {
|
|||
}
|
||||
}
|
||||
|
||||
func (s *SQL) exec(query string, args ...interface{}) (int64, error) {
|
||||
func (s *SQL) exec(ex Executor, query string, args ...interface{}) (int64, error) {
|
||||
s.logQuery(query, args...)
|
||||
attempt, maxAttempts := 0, 3
|
||||
Retry:
|
||||
attempt++
|
||||
result, err := s.conn.Exec(query, args...)
|
||||
result, err := ex.Exec(query, args...)
|
||||
if isLockTimeoutError(err) {
|
||||
if attempt <= maxAttempts {
|
||||
//Error 1205: Lock wait timeout exceeded; try restarting transaction
|
||||
|
|
45
db/interfaces.go
Normal file
45
db/interfaces.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/lbryio/lbry.go/v2/extras/errors"
|
||||
)
|
||||
|
||||
// Executor can perform SQL queries.
|
||||
type Executor interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// Transactor can commit and rollback, on top of being able to execute queries.
|
||||
type Transactor interface {
|
||||
Commit() error
|
||||
Rollback() error
|
||||
|
||||
Executor
|
||||
}
|
||||
|
||||
// Begin begins a transaction
|
||||
func Begin(db interface{}) (Transactor, error) {
|
||||
type beginner interface {
|
||||
Begin() (Transactor, error)
|
||||
}
|
||||
|
||||
creator, ok := db.(beginner)
|
||||
if ok {
|
||||
return creator.Begin()
|
||||
}
|
||||
|
||||
type sqlBeginner interface {
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
creator2, ok := db.(sqlBeginner)
|
||||
if ok {
|
||||
return creator2.Begin()
|
||||
}
|
||||
|
||||
return nil, errors.Err("database does not support transactions")
|
||||
}
|
|
@ -1,27 +1,39 @@
|
|||
package shared
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBlobTrace_Serialize(t *testing.T) {
|
||||
hostName, err := os.Hostname()
|
||||
require.NoError(t, err)
|
||||
|
||||
stack := NewBlobTrace(10*time.Second, "test")
|
||||
stack.Stack(20*time.Second, "test2")
|
||||
stack.Stack(30*time.Second, "test3")
|
||||
serialized, err := stack.Serialize()
|
||||
assert.NoError(t, err)
|
||||
t.Log(serialized)
|
||||
expected := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}"
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := `{"stacks":[{"timing":10000000000,"origin_name":"test","host_name":"` +
|
||||
hostName +
|
||||
`"},{"timing":20000000000,"origin_name":"test2","host_name":"` +
|
||||
hostName +
|
||||
`"},{"timing":30000000000,"origin_name":"test3","host_name":"` +
|
||||
hostName +
|
||||
`"}]}`
|
||||
assert.Equal(t, expected, serialized)
|
||||
}
|
||||
|
||||
func TestBlobTrace_Deserialize(t *testing.T) {
|
||||
serialized := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}"
|
||||
serialized := `{"stacks":[{"timing":10000000000,"origin_name":"test"},{"timing":20000000000,"origin_name":"test2"},{"timing":30000000000,"origin_name":"test3"}]}`
|
||||
stack, err := Deserialize(serialized)
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Len(t, stack.Stacks, 3)
|
||||
assert.Equal(t, stack.Stacks[0].Timing, 10*time.Second)
|
||||
assert.Equal(t, stack.Stacks[1].Timing, 20*time.Second)
|
||||
|
|
Loading…
Reference in a new issue