Insert in tx #51
3 changed files with 125 additions and 67 deletions
db/db.go (71 changed lines)

@@ -97,7 +97,7 @@ func (s *SQL) AddBlob(hash string, length int, isStored bool) error {
 		return errors.Err("not connected")
 	}

-	_, err := s.insertBlob(hash, length, isStored)
+	_, err := s.insertBlob(s.conn, hash, length, isStored)
 	return err
 }

@@ -163,7 +163,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
 		//args = append(args, hash, true, stream.MaxBlobSize, dayAgo)
 	}
 	q = strings.TrimSuffix(q, ",")
-	_, err := s.exec(q)
+	_, err := s.exec(s.conn, q)
 	if err != nil {
 		return err
 	}
@@ -171,7 +171,7 @@ func (s *SQL) insertBlobs(hashes []string) error {
 	return nil
 }

-func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error) {
+func (s *SQL) insertBlob(ex Executor, hash string, length int, isStored bool) (int64, error) {
 	if length <= 0 {
 		return 0, errors.Err("length must be positive")
 	}
@@ -188,13 +188,13 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
 		q = "INSERT INTO blob_ (hash, is_stored, length) VALUES (" + qt.Qs(len(args)) + ") ON DUPLICATE KEY UPDATE is_stored = (is_stored or VALUES(is_stored))"
 	}

-	blobID, err := s.exec(q, args...)
+	blobID, err := s.exec(ex, q, args...)
 	if err != nil {
 		return 0, err
 	}

 	if blobID == 0 {
-		err = s.conn.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
+		err = ex.QueryRow("SELECT id FROM blob_ WHERE hash = ?", hash).Scan(&blobID)
 		if err != nil {
 			return 0, errors.Err(err)
 		}
@@ -203,7 +203,7 @@ func (s *SQL) insertBlob(hash string, length int, isStored bool) (int64, error)
 	}

 	if s.TrackAccess == TrackAccessBlobs {
-		err := s.touchBlobs([]uint64{uint64(blobID)})
+		err := s.touchBlobs(ex, []uint64{uint64(blobID)})
 		if err != nil {
 			return 0, errors.Err(err)
 		}
@@ -227,7 +227,7 @@ func (s *SQL) insertStream(hash string, sdBlobID int64) (int64, error) {
 		q = "INSERT IGNORE INTO stream (hash, sd_blob_id) VALUES (" + qt.Qs(len(args)) + ")"
 	}

-	streamID, err := s.exec(q, args...)
+	streamID, err := s.exec(s.conn, q, args...)
 	if err != nil {
 		return 0, errors.Err(err)
 	}
@@ -266,16 +266,16 @@ func (s *SQL) HasBlobs(hashes []string, touch bool) (map[string]bool, error) {

 	if touch {
 		if s.TrackAccess == TrackAccessBlobs {
-			s.touchBlobs(idsNeedingTouch)
+			_ = s.touchBlobs(s.conn, idsNeedingTouch)
 		} else if s.TrackAccess == TrackAccessStreams {
-			s.touchStreams(idsNeedingTouch)
+			_ = s.touchStreams(idsNeedingTouch)
 		}
 	}

 	return exists, err
 }

-func (s *SQL) touchBlobs(blobIDs []uint64) error {
+func (s *SQL) touchBlobs(ex Executor, blobIDs []uint64) error {
 	if len(blobIDs) == 0 {
 		return nil
 	}
@@ -288,7 +288,7 @@ func (s *SQL) touchBlobs(blobIDs []uint64) error {
 	}

 	startTime := time.Now()
-	_, err := s.exec(query, args...)
+	_, err := s.exec(ex, query, args...)
 	log.Debugf("touched %d blobs and took %s", len(blobIDs), time.Since(startTime))
 	return errors.Err(err)
 }
@@ -306,7 +306,7 @@ func (s *SQL) touchStreams(streamIDs []uint64) error {
 	}

 	startTime := time.Now()
-	_, err := s.exec(query, args...)
+	_, err := s.exec(s.conn, query, args...)
 	log.Debugf("touched %d streams and took %s", len(streamIDs), time.Since(startTime))
 	return errors.Err(err)
 }
@@ -406,16 +406,16 @@ WHERE b.is_stored = 1 and b.hash IN (` + qt.Qs(len(batch)) + `)`
 // NOTE: If SoftDelete is enabled, streams will never be deleted
 func (s *SQL) Delete(hash string) error {
 	if s.SoftDelete {
-		_, err := s.exec("UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
+		_, err := s.exec(s.conn, "UPDATE blob_ SET is_stored = 0 WHERE hash = ?", hash)
 		return errors.Err(err)
 	}

-	_, err := s.exec("DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
+	_, err := s.exec(s.conn, "DELETE FROM stream WHERE sd_blob_id = (SELECT id FROM blob_ WHERE hash = ?)", hash)
 	if err != nil {
 		return errors.Err(err)
 	}

-	_, err = s.exec("DELETE FROM blob_ WHERE hash = ?", hash)
+	_, err = s.exec(s.conn, "DELETE FROM blob_ WHERE hash = ?", hash)
 	return errors.Err(err)
 }

@@ -590,7 +590,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
 		return errors.Err("not connected")
 	}

-	sdBlobID, err := s.insertBlob(sdHash, sdBlobLength, true)
+	sdBlobID, err := s.insertBlob(s.conn, sdHash, sdBlobLength, true)
 	if err != nil {
 		return err
 	}
@@ -600,6 +600,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
 		return err
 	}

+	return withTx(s.conn, func(tx Transactor) error {
 		// insert content blobs and connect them to stream
 		for _, contentBlob := range sdBlob.Blobs {
 			if contentBlob.BlobHash == "" {
@@ -607,13 +608,13 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
 				continue
 			}

-			blobID, err := s.insertBlob(contentBlob.BlobHash, contentBlob.Length, false)
+			blobID, err := s.insertBlob(tx, contentBlob.BlobHash, contentBlob.Length, false)
 			if err != nil {
 				return err
 			}

 			args := []interface{}{streamID, blobID, contentBlob.BlobNum}
-			_, err = s.exec(
+			_, err = s.exec(tx,
 				"INSERT IGNORE INTO stream_blob (stream_id, blob_id, num) VALUES ("+qt.Qs(len(args))+")",
 				args...,
 			)
@@ -622,6 +623,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob SdBlob) error {
 			}
 		}
 		return nil
+	})
 }

 // GetHashRange gets the smallest and biggest hashes in the db
@@ -694,38 +696,37 @@ func (s *SQL) GetStoredHashesInRange(ctx context.Context, start, end bits.Bitmap
 }

 // txFunc is a function that can be wrapped in a transaction
-type txFunc func(tx *sql.Tx) error
+type txFunc func(tx Transactor) error

-// withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one.
-// if dbOrTx is an sql.DB, a new transaction is started
+// withTx wraps a function in an sql transaction. the transaction is committed if there's
+// no error, or rolled back if there is one. if dbOrTx is not a Transactor (e.g. if it's
+// an *sql.DB), withTx attempts to start a new transaction to use.
 func withTx(dbOrTx interface{}, f txFunc) (err error) {
-	var tx *sql.Tx
+	var tx Transactor
+	var ok bool

-	switch t := dbOrTx.(type) {
-	case *sql.Tx:
-		tx = t
-	case *sql.DB:
-		tx, err = t.Begin()
+	tx, ok = dbOrTx.(Transactor)
+	if !ok {
+		tx, err = Begin(dbOrTx)
 		if err != nil {
 			return err
 		}
+	}

 	defer func() {
 		if p := recover(); p != nil {
 			if rollBackError := tx.Rollback(); rollBackError != nil {
-				log.Error("failed to rollback tx on panic - ", rollBackError)
+				log.Error("failed to rollback tx on panic: ", rollBackError)
 			}
-			panic(p)
+			err = errors.Prefix("panic", p)
 		} else if err != nil {
 			if rollBackError := tx.Rollback(); rollBackError != nil {
-				log.Error("failed to rollback tx on panic - ", rollBackError)
+				log.Error("failed to rollback tx: ", rollBackError)
 			}
 		} else {
 			err = errors.Err(tx.Commit())
 		}
 	}()
-	default:
-		return errors.Err("db or tx required")
-	}

 	return f(tx)
 }
@@ -739,12 +740,12 @@ func closeRows(rows *sql.Rows) {
 		}
 	}
 }

-func (s *SQL) exec(query string, args ...interface{}) (int64, error) {
+func (s *SQL) exec(ex Executor, query string, args ...interface{}) (int64, error) {
 	s.logQuery(query, args...)
 	attempt, maxAttempts := 0, 3
 Retry:
 	attempt++
-	result, err := s.conn.Exec(query, args...)
+	result, err := ex.Exec(query, args...)
 	if isLockTimeoutError(err) {
 		if attempt <= maxAttempts {
 			//Error 1205: Lock wait timeout exceeded; try restarting transaction
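The thread running through the db/db.go changes: every write helper now takes an Executor as its first argument, so the same code path can run against the bare connection (s.conn) or against a transaction handed out by withTx, which is how AddSDBlob now groups its content-blob inserts. Below is a minimal sketch of the pattern, written as if it lived inside package db; the helper name and its argument are hypothetical, while the insertBlob, withTx, and Transactor signatures are the ones introduced above.

// addBlobsAtomically is a hypothetical helper (not part of this PR) showing
// how several insertBlob calls can be grouped into one transaction via withTx,
// the same way AddSDBlob now wraps its content-blob inserts.
func (s *SQL) addBlobsAtomically(hashes map[string]int) error {
	if s.conn == nil {
		return errors.Err("not connected")
	}

	return withTx(s.conn, func(tx Transactor) error {
		for hash, length := range hashes {
			// insertBlob now takes an Executor, so the same code works on the
			// raw *sql.DB connection or inside an open transaction.
			if _, err := s.insertBlob(tx, hash, length, true); err != nil {
				return err // withTx rolls the transaction back on error
			}
		}
		return nil // withTx commits when the callback returns nil
	})
}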
db/interfaces.go (new file, 45 added lines)

@@ -0,0 +1,45 @@
+package db
+
+import (
+	"database/sql"
+
+	"github.com/lbryio/lbry.go/v2/extras/errors"
+)
+
+// Executor can perform SQL queries.
+type Executor interface {
+	Exec(query string, args ...interface{}) (sql.Result, error)
+	Query(query string, args ...interface{}) (*sql.Rows, error)
+	QueryRow(query string, args ...interface{}) *sql.Row
+}
+
+// Transactor can commit and rollback, on top of being able to execute queries.
+type Transactor interface {
+	Commit() error
+	Rollback() error
+
+	Executor
+}
+
+// Begin begins a transaction
+func Begin(db interface{}) (Transactor, error) {
+	type beginner interface {
+		Begin() (Transactor, error)
+	}
+
+	creator, ok := db.(beginner)
+	if ok {
+		return creator.Begin()
+	}
+
+	type sqlBeginner interface {
+		Begin() (*sql.Tx, error)
+	}
+
+	creator2, ok := db.(sqlBeginner)
+	if ok {
+		return creator2.Begin()
+	}
+
+	return nil, errors.Err("database does not support transactions")
+}
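The sqlBeginner fallback in Begin works because the standard database/sql types already satisfy these interfaces: *sql.DB and *sql.Tx both implement Executor, and *sql.Tx also implements Transactor, so `return creator2.Begin()` can hand a *sql.Tx back as a Transactor directly. A short compile-time sketch, assumed to sit in package db next to interfaces.go (not part of the PR):

// Illustrative compile-time checks: the stdlib types match the new method
// sets, which is what lets exec/insertBlob accept either the raw connection
// or an open transaction.
var (
	_ Executor   = (*sql.DB)(nil)
	_ Executor   = (*sql.Tx)(nil)
	_ Transactor = (*sql.Tx)(nil)
)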
shared package (BlobTrace tests)

@@ -1,27 +1,39 @@
 package shared

 import (
+	"os"
 	"testing"
 	"time"

 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )

 func TestBlobTrace_Serialize(t *testing.T) {
+	hostName, err := os.Hostname()
+	require.NoError(t, err)
+
 	stack := NewBlobTrace(10*time.Second, "test")
 	stack.Stack(20*time.Second, "test2")
 	stack.Stack(30*time.Second, "test3")
 	serialized, err := stack.Serialize()
-	assert.NoError(t, err)
-	t.Log(serialized)
-	expected := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}"
+	require.NoError(t, err)
+	expected := `{"stacks":[{"timing":10000000000,"origin_name":"test","host_name":"` +
+		hostName +
+		`"},{"timing":20000000000,"origin_name":"test2","host_name":"` +
+		hostName +
+		`"},{"timing":30000000000,"origin_name":"test3","host_name":"` +
+		hostName +
+		`"}]}`
 	assert.Equal(t, expected, serialized)
 }

 func TestBlobTrace_Deserialize(t *testing.T) {
-	serialized := "{\"stacks\":[{\"timing\":10000000000,\"origin_name\":\"test\"},{\"timing\":20000000000,\"origin_name\":\"test2\"},{\"timing\":30000000000,\"origin_name\":\"test3\"}]}"
+	serialized := `{"stacks":[{"timing":10000000000,"origin_name":"test"},{"timing":20000000000,"origin_name":"test2"},{"timing":30000000000,"origin_name":"test3"}]}`
 	stack, err := Deserialize(serialized)
-	assert.NoError(t, err)
+	require.NoError(t, err)

 	assert.Len(t, stack.Stacks, 3)
 	assert.Equal(t, stack.Stacks[0].Timing, 10*time.Second)
 	assert.Equal(t, stack.Stacks[1].Timing, 20*time.Second)