From 0aee55d2494b20d6bc96b7e785267a549ce22260 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Thu, 1 Mar 2018 16:12:53 -0500 Subject: [PATCH] add sql txn where needed. closes lbryio/reflector-cluster#58 --- cluster/cluster.go | 6 +++ cmd/cluster.go | 5 +- db/db.go | 115 ++++++++++++++++++++++++++++++--------------- 3 files changed, 85 insertions(+), 41 deletions(-) diff --git a/cluster/cluster.go b/cluster/cluster.go index c717441..9b8faf0 100644 --- a/cluster/cluster.go +++ b/cluster/cluster.go @@ -1,6 +1,9 @@ package cluster import ( + "io/ioutil" + baselog "log" + "github.com/lbryio/lbry.go/errors" "github.com/hashicorp/serf/serf" @@ -13,6 +16,9 @@ func Connect(nodeName, addr string, port int) (*serf.Serf, <-chan serf.Event, er conf.MemberlistConfig.AdvertisePort = port conf.NodeName = nodeName + nullLogger := baselog.New(ioutil.Discard, "", 0) + conf.Logger = nullLogger + eventCh := make(chan serf.Event) conf.EventCh = eventCh diff --git a/cmd/cluster.go b/cmd/cluster.go index 5fa4ccd..1c8bd4b 100644 --- a/cmd/cluster.go +++ b/cmd/cluster.go @@ -69,8 +69,9 @@ func clusterCmd(cmd *cobra.Command, args []string) { if event.EventType() == serf.EventMemberJoin && len(memberEvent.Members) == 1 && memberEvent.Members[0].Name == nodeName { // ignore event from my own joining of the cluster } else { - spew.Dump(c.Members()) - log.Printf("my hash range is now %d\n", getHashRangeStart(nodeName, getAliveMembers(c.Members()))) + //spew.Dump(c.Members()) + alive := getAliveMembers(c.Members()) + log.Printf("%s: my hash range is now %d of %d\n", nodeName, getHashRangeStart(nodeName, alive), len(alive)) // figure out my new hash range based on the start and the number of alive members // get hashes in that range that need announcing // announce them diff --git a/db/db.go b/db/db.go index 01b0e5d..cf56862 100644 --- a/db/db.go +++ b/db/db.go @@ -3,10 +3,9 @@ package db import ( "database/sql" - "github.com/lbryio/reflector.go/types" - "github.com/lbryio/lbry.go/errors" qtools "github.com/lbryio/query.go" + "github.com/lbryio/reflector.go/types" _ "github.com/go-sql-driver/mysql" log "github.com/sirupsen/logrus" @@ -48,6 +47,12 @@ func (s *SQL) AddBlob(hash string, length int, stored bool) error { return errors.Err("not connected") } + return withTx(s.conn, func(tx *sql.Tx) error { + return addBlob(tx, hash, length, stored) + }) +} + +func addBlob(tx *sql.Tx, hash string, length int, stored bool) error { if length <= 0 { return errors.Err("length must be positive") } @@ -57,7 +62,7 @@ func (s *SQL) AddBlob(hash string, length int, stored bool) error { logQuery(query, args...) - stmt, err := s.conn.Prepare(query) + stmt, err := tx.Prepare(query) if err != nil { return errors.Err(err) } @@ -93,48 +98,20 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er return errors.Err("not connected") } - // TODO: should do all this in transaction - - // insert sd blob - err := s.AddBlob(sdHash, sdBlobLength, true) - if err != nil { - return err - } - - // insert stream - query := "INSERT IGNORE INTO stream (hash, sd_hash) VALUES (?,?)" - args := []interface{}{sdBlob.StreamHash, sdHash} - - logQuery(query, args...) - - stmt, err := s.conn.Prepare(query) - if err != nil { - return errors.Err(err) - } - - _, err = stmt.Exec(args...) - if err != nil { - return errors.Err(err) - } - - // insert content blobs and connect them to stream - for _, contentBlob := range sdBlob.Blobs { - if contentBlob.BlobHash == "" { - // null terminator blob - continue - } - - err := s.AddBlob(contentBlob.BlobHash, contentBlob.Length, false) + return withTx(s.conn, func(tx *sql.Tx) error { + // insert sd blob + err := addBlob(tx, sdHash, sdBlobLength, true) if err != nil { return err } - query := "INSERT IGNORE INTO stream_blob (stream_hash, blob_hash, num) VALUES (?,?,?)" - args := []interface{}{sdBlob.StreamHash, contentBlob.BlobHash, contentBlob.BlobNum} + // insert stream + query := "INSERT IGNORE INTO stream (hash, sd_hash) VALUES (?,?)" + args := []interface{}{sdBlob.StreamHash, sdHash} logQuery(query, args...) - stmt, err := s.conn.Prepare(query) + stmt, err := tx.Prepare(query) if err != nil { return errors.Err(err) } @@ -143,9 +120,69 @@ func (s *SQL) AddSDBlob(sdHash string, sdBlobLength int, sdBlob types.SdBlob) er if err != nil { return errors.Err(err) } + + // insert content blobs and connect them to stream + for _, contentBlob := range sdBlob.Blobs { + if contentBlob.BlobHash == "" { + // null terminator blob + continue + } + + err := addBlob(tx, contentBlob.BlobHash, contentBlob.Length, false) + if err != nil { + return err + } + + query := "INSERT IGNORE INTO stream_blob (stream_hash, blob_hash, num) VALUES (?,?,?)" + args := []interface{}{sdBlob.StreamHash, contentBlob.BlobHash, contentBlob.BlobNum} + + logQuery(query, args...) + + stmt, err := tx.Prepare(query) + if err != nil { + return errors.Err(err) + } + + _, err = stmt.Exec(args...) + if err != nil { + return errors.Err(err) + } + } + return nil + }) +} + +// txFunc is a function that can be wrapped in a transaction +type txFunc func(tx *sql.Tx) error + +// withTx wraps a function in an sql transaction. the transaction is committed if there's no error, or rolled back if there is one. +// if dbOrTx is an sql.DB, a new transaction is started +func withTx(dbOrTx interface{}, f txFunc) (err error) { + var tx *sql.Tx + + switch t := dbOrTx.(type) { + case *sql.Tx: + tx = t + case *sql.DB: + tx, err = t.Begin() + if err != nil { + return err + } + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) + } else if err != nil { + tx.Rollback() + } else { + err = errors.Err(tx.Commit()) + } + }() + default: + return errors.Err("db or tx required") } - return nil + return f(tx) } func schema() {