add sql txn where needed. closes lbryio/reflector-cluster#58

This commit is contained in:
Alex Grintsvayg 2018-03-01 16:12:53 -05:00
parent 32a27c4e4d
commit 0aee55d249
3 changed files with 85 additions and 41 deletions

View file

@ -1,6 +1,9 @@
package cluster package cluster
import ( import (
"io/ioutil"
baselog "log"
"github.com/lbryio/lbry.go/errors" "github.com/lbryio/lbry.go/errors"
"github.com/hashicorp/serf/serf" "github.com/hashicorp/serf/serf"
@ -13,6 +16,9 @@ func Connect(nodeName, addr string, port int) (*serf.Serf, <-chan serf.Event, er
conf.MemberlistConfig.AdvertisePort = port conf.MemberlistConfig.AdvertisePort = port
conf.NodeName = nodeName conf.NodeName = nodeName
nullLogger := baselog.New(ioutil.Discard, "", 0)
conf.Logger = nullLogger
eventCh := make(chan serf.Event) eventCh := make(chan serf.Event)
conf.EventCh = eventCh conf.EventCh = eventCh

View file

@ -69,8 +69,9 @@ func clusterCmd(cmd *cobra.Command, args []string) {
if event.EventType() == serf.EventMemberJoin && len(memberEvent.Members) == 1 && memberEvent.Members[0].Name == nodeName { if event.EventType() == serf.EventMemberJoin && len(memberEvent.Members) == 1 && memberEvent.Members[0].Name == nodeName {
// ignore event from my own joining of the cluster // ignore event from my own joining of the cluster
} else { } else {
spew.Dump(c.Members()) //spew.Dump(c.Members())
log.Printf("my hash range is now %d\n", getHashRangeStart(nodeName, getAliveMembers(c.Members()))) alive := getAliveMembers(c.Members())
log.Printf("%s: my hash range is now %d of %d\n", nodeName, getHashRangeStart(nodeName, alive), len(alive))
// figure out my new hash range based on the start and the number of alive members // figure out my new hash range based on the start and the number of alive members
// get hashes in that range that need announcing // get hashes in that range that need announcing
// announce them // announce them

View file

@ -3,10 +3,9 @@ package db
import ( import (
"database/sql" "database/sql"
"github.com/lbryio/reflector.go/types"
"github.com/lbryio/lbry.go/errors" "github.com/lbryio/lbry.go/errors"
qtools "github.com/lbryio/query.go" qtools "github.com/lbryio/query.go"
"github.com/lbryio/reflector.go/types"
_ "github.com/go-sql-driver/mysql" _ "github.com/go-sql-driver/mysql"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@ -48,6 +47,12 @@ func (s *SQL) AddBlob(hash string, length int, stored bool) error {
return errors.Err("not connected") return errors.Err("not connected")
} }
return withTx(s.conn, func(tx *sql.Tx) error {
return addBlob(tx, hash, length, stored)
})
}
func addBlob(tx *sql.Tx, hash string, length int, stored bool) error {
if length <= 0 { if length <= 0 {
return errors.Err("length must be positive") return errors.Err("length must be positive")
} }
@ -57,7 +62,7 @@ func (s *SQL) AddBlob(hash string, length int, stored bool) error {
logQuery(query, args...) logQuery(query, args...)
stmt, err := s.conn.Prepare(query) stmt, err := tx.Prepare(query)
if err != nil { if err != nil {
return errors.Err(err) return errors.Err(err)
} }
@ -93,10 +98,9 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er
return errors.Err("not connected") return errors.Err("not connected")
} }
// TODO: should do all this in transaction return withTx(s.conn, func(tx *sql.Tx) error {
// insert sd blob // insert sd blob
err := s.AddBlob(sdHash, sdBlobLength, true) err := addBlob(tx, sdHash, sdBlobLength, true)
if err != nil { if err != nil {
return err return err
} }
@ -107,7 +111,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er
logQuery(query, args...) logQuery(query, args...)
stmt, err := s.conn.Prepare(query) stmt, err := tx.Prepare(query)
if err != nil { if err != nil {
return errors.Err(err) return errors.Err(err)
} }
@ -124,7 +128,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er
continue continue
} }
err := s.AddBlob(contentBlob.BlobHash, contentBlob.Length, false) err := addBlob(tx, contentBlob.BlobHash, contentBlob.Length, false)
if err != nil { if err != nil {
return err return err
} }
@ -134,7 +138,7 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er
logQuery(query, args...) logQuery(query, args...)
stmt, err := s.conn.Prepare(query) stmt, err := tx.Prepare(query)
if err != nil { if err != nil {
return errors.Err(err) return errors.Err(err)
} }
@ -144,8 +148,41 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er
return errors.Err(err) return errors.Err(err)
} }
} }
return nil return nil
})
}
// txFunc is a function that can be wrapped in a transaction
type txFunc func(tx *sql.Tx) error
// withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one.
// if dbOrTx is an sql.DB, a new transaction is started
func withTx(dbOrTx interface{}, f txFunc) (err error) {
var tx *sql.Tx
switch t := dbOrTx.(type) {
case *sql.Tx:
tx = t
case *sql.DB:
tx, err = t.Begin()
if err != nil {
return err
}
defer func() {
if p := recover(); p != nil {
tx.Rollback()
panic(p)
} else if err != nil {
tx.Rollback()
} else {
err = errors.Err(tx.Commit())
}
}()
default:
return errors.Err("db or tx required")
}
return f(tx)
} }
func schema() { func schema() {