From 99a3a1d091d13401cd635e9b376ffec8f72ef156 Mon Sep 17 00:00:00 2001 From: Alex Grintsvayg Date: Tue, 1 Aug 2017 13:00:14 -0400 Subject: [PATCH] make merge compatible with an existing transaction --- templates/23_merge.tpl | 47 ++++++++++++++++++++++++++---------------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/templates/23_merge.tpl b/templates/23_merge.tpl index 6b5cdd2..d84ed28 100644 --- a/templates/23_merge.tpl +++ b/templates/23_merge.tpl @@ -4,32 +4,42 @@ {{- else -}} {{- $dot := . }} // Merge combines two {{$tableNamePlural}} into one. The primary record will be kept, and the secondary will be deleted. -func Merge{{$tableNamePlural}}(exec boil.Executor, primaryID uint64, secondaryID uint64) error { - txdb, ok := exec.(boil.Beginner) +func Merge{{$tableNamePlural}}(exec boil.Executor, primaryID uint64, secondaryID uint64) (err error) { + tx, ok := exec.(boil.Transactor) if !ok { - return errors.New("database does not support transactions") - } + txdb, ok := exec.(boil.Beginner) + if !ok { + return errors.New("database does not support transactions") + } - tx, txErr := txdb.Begin() - if txErr != nil { - return txErr + tx, err = txdb.Begin() + if err != nil { + return err + } + + defer func() { + if p := recover(); p != nil { + tx.Rollback() + panic(p) // Rollback, then propagate panic + } else if err != nil { + tx.Rollback() + } else { + err = tx.Commit() + } + }() } primary, err := Find{{$tableNameSingular}}(tx, primaryID) if err != nil { - tx.Rollback() return err - } - if primary == nil { + } else if primary == nil { return errors.New("Primary {{$tableNameSingular}} not found") } secondary, err := Find{{$tableNameSingular}}(tx, secondaryID) if err != nil { - tx.Rollback() return err - } - if secondary == nil { + } else if secondary == nil { return errors.New("Secondary {{$tableNameSingular}} not found") } @@ -59,9 +69,13 @@ func Merge{{$tableNamePlural}}(exec boil.Executor, primaryID uint64, secondaryID {{- end }} } - err = mergeModels(tx, primaryID, secondaryID, foreignKeys, conflictingKeys) + sqlTx, ok := tx.(*sql.Tx) + if !ok { + return errors.New("tx must be an sql.Tx") + } + + err = mergeModels(sqlTx, primaryID, secondaryID, foreignKeys, conflictingKeys) if err != nil { - tx.Rollback() return err } @@ -80,17 +94,14 @@ func Merge{{$tableNamePlural}}(exec boil.Executor, primaryID uint64, secondaryID err = primary.Update(tx) if err != nil { - tx.Rollback() return errors.WithStack(err) } err = secondary.Delete(tx) if err != nil { - tx.Rollback() return errors.WithStack(err) } - tx.Commit() return nil }