diff --git a/queries/query_builders.go b/queries/query_builders.go index 3136693..7cf753f 100644 --- a/queries/query_builders.go +++ b/queries/query_builders.go @@ -283,6 +283,49 @@ func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bo return buf.String() } +// BuildUpsertQueryMSSQL builds a SQL statement string using the upsertData provided. +func BuildUpsertQueryMSSQL(dia Dialect, tableName string, primary, update, insert []string, output []string) string { + insert = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, insert) + + buf := strmangle.GetBuffer() + defer strmangle.PutBuffer(buf) + + startIndex := 1 + + fmt.Fprintf(buf, "MERGE INTO %s as [t]\n", tableName) + fmt.Fprintf(buf, "USING (SELECT %s) as [s] ([%s])\n", + strmangle.Placeholders(dia.IndexPlaceholders, len(primary), startIndex, 1), + strings.Join(primary, string(dia.RQ)+","+string(dia.LQ))) + fmt.Fprint(buf, "ON (") + for i, v := range primary { + if i != 0 { + fmt.Fprint(buf, " AND ") + } + fmt.Fprintf(buf, "[s].[%s] = [t].[%s]", v, v) + } + fmt.Fprint(buf, ")\n") + + startIndex = len(primary) + 1 + + fmt.Fprint(buf, "WHEN MATCHED THEN ") + fmt.Fprintf(buf, "UPDATE SET %s\n", strmangle.SetParamNames(string(dia.LQ), string(dia.RQ), startIndex, update)) + + startIndex = len(primary) + 1 + len(update) + + fmt.Fprint(buf, "WHEN NOT MATCHED THEN ") + fmt.Fprintf(buf, "INSERT (%s) VALUES (%s)", + strings.Join(insert, ", "), + strmangle.Placeholders(dia.IndexPlaceholders, len(insert), startIndex, 1)) + + if len(output) > 0 { + fmt.Fprintf(buf, "\nOUTPUT INSERTED.[%s];", strings.Join(output, "],INSERTED.[")) + } else { + fmt.Fprint(buf, ";") + } + + return buf.String() +} + func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) { if len(q.groupBy) != 0 { fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", ")) diff --git a/templates/01_types.tpl b/templates/01_types.tpl index bcf0a94..1f8d015 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -35,10 +35,8 @@ var ( {{$varNameSingular}}InsertCache = make(map[string]insertCache) {{$varNameSingular}}UpdateCacheMut sync.RWMutex {{$varNameSingular}}UpdateCache = make(map[string]updateCache) - {{if ne .DriverName "mssql"}} {{$varNameSingular}}UpsertCacheMut sync.RWMutex {{$varNameSingular}}UpsertCache = make(map[string]insertCache) - {{end}} ) var ( diff --git a/templates/02_hooks.tpl b/templates/02_hooks.tpl index 073f8ec..9815639 100644 --- a/templates/02_hooks.tpl +++ b/templates/02_hooks.tpl @@ -4,17 +4,13 @@ var {{$varNameSingular}}BeforeInsertHooks []{{$tableNameSingular}}Hook var {{$varNameSingular}}BeforeUpdateHooks []{{$tableNameSingular}}Hook var {{$varNameSingular}}BeforeDeleteHooks []{{$tableNameSingular}}Hook -{{if ne .DriverName "mssql" -}} var {{$varNameSingular}}BeforeUpsertHooks []{{$tableNameSingular}}Hook -{{- end}} var {{$varNameSingular}}AfterInsertHooks []{{$tableNameSingular}}Hook var {{$varNameSingular}}AfterSelectHooks []{{$tableNameSingular}}Hook var {{$varNameSingular}}AfterUpdateHooks []{{$tableNameSingular}}Hook var {{$varNameSingular}}AfterDeleteHooks []{{$tableNameSingular}}Hook -{{if ne .DriverName "mssql" -}} var {{$varNameSingular}}AfterUpsertHooks []{{$tableNameSingular}}Hook -{{- end}} // doBeforeInsertHooks executes all "before insert" hooks. func (o *{{$tableNameSingular}}) doBeforeInsertHooks(exec boil.Executor) (err error) { @@ -49,7 +45,6 @@ func (o *{{$tableNameSingular}}) doBeforeDeleteHooks(exec boil.Executor) (err er return nil } -{{- if ne .DriverName "mssql" -}} // doBeforeUpsertHooks executes all "before Upsert" hooks. func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err error) { for _, hook := range {{$varNameSingular}}BeforeUpsertHooks { @@ -60,7 +55,6 @@ func (o *{{$tableNameSingular}}) doBeforeUpsertHooks(exec boil.Executor) (err er return nil } -{{- end}} // doAfterInsertHooks executes all "after Insert" hooks. func (o *{{$tableNameSingular}}) doAfterInsertHooks(exec boil.Executor) (err error) { @@ -106,7 +100,6 @@ func (o *{{$tableNameSingular}}) doAfterDeleteHooks(exec boil.Executor) (err err return nil } -{{- if ne .DriverName "mssql" -}} // doAfterUpsertHooks executes all "after Upsert" hooks. func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err error) { for _, hook := range {{$varNameSingular}}AfterUpsertHooks { @@ -117,7 +110,6 @@ func (o *{{$tableNameSingular}}) doAfterUpsertHooks(exec boil.Executor) (err err return nil } -{{- end}} // Add{{$tableNameSingular}}Hook registers your hook function for all future operations. func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular}}Hook {{$tableNameSingular}}Hook) { @@ -128,10 +120,8 @@ func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular} {{$varNameSingular}}BeforeUpdateHooks = append({{$varNameSingular}}BeforeUpdateHooks, {{$varNameSingular}}Hook) case boil.BeforeDeleteHook: {{$varNameSingular}}BeforeDeleteHooks = append({{$varNameSingular}}BeforeDeleteHooks, {{$varNameSingular}}Hook) - {{if ne .DriverName "mssql" -}} case boil.BeforeUpsertHook: {{$varNameSingular}}BeforeUpsertHooks = append({{$varNameSingular}}BeforeUpsertHooks, {{$varNameSingular}}Hook) - {{- end}} case boil.AfterInsertHook: {{$varNameSingular}}AfterInsertHooks = append({{$varNameSingular}}AfterInsertHooks, {{$varNameSingular}}Hook) case boil.AfterSelectHook: @@ -140,10 +130,8 @@ func Add{{$tableNameSingular}}Hook(hookPoint boil.HookPoint, {{$varNameSingular} {{$varNameSingular}}AfterUpdateHooks = append({{$varNameSingular}}AfterUpdateHooks, {{$varNameSingular}}Hook) case boil.AfterDeleteHook: {{$varNameSingular}}AfterDeleteHooks = append({{$varNameSingular}}AfterDeleteHooks, {{$varNameSingular}}Hook) - {{if ne .DriverName "mssql" -}} case boil.AfterUpsertHook: {{$varNameSingular}}AfterUpsertHooks = append({{$varNameSingular}}AfterUpsertHooks, {{$varNameSingular}}Hook) - {{- end}} } } {{- end}} diff --git a/templates/17_upsert.tpl b/templates/17_upsert.tpl index 0aeb479..70df4c7 100644 --- a/templates/17_upsert.tpl +++ b/templates/17_upsert.tpl @@ -1,29 +1,28 @@ -{{- if ne .DriverName "mssql" -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} {{- $schemaTable := .Table.Name | .SchemaTable}} // UpsertG attempts an insert, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) UpsertG({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { - return o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) +func (o *{{$tableNameSingular}}) UpsertG({{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { + return o.Upsert(boil.GetDB(), {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...) } // UpsertGP attempts an insert, and does an update or ignore on conflict. Panics on error. -func (o *{{$tableNameSingular}}) UpsertGP({{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { - if err := o.Upsert(boil.GetDB(), {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { +func (o *{{$tableNameSingular}}) UpsertGP({{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(boil.GetDB(), {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { panic(boil.WrapErr(err)) } } // UpsertP attempts an insert using an executor, and does an update or ignore on conflict. // UpsertP panics on error. -func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { - if err := o.Upsert(exec, {{if ne .DriverName "mysql"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { +func (o *{{$tableNameSingular}}) UpsertP(exec boil.Executor, {{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) { + if err := o.Upsert(exec, {{if eq .DriverName "postgres"}}updateOnConflict, conflictColumns, {{end}}updateColumns, whitelist...); err != nil { panic(boil.WrapErr(err)) } } // Upsert attempts an insert using an executor, and does an update or ignore on conflict. -func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName "mysql"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { +func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if eq .DriverName "postgres"}}updateOnConflict bool, conflictColumns []string, {{end}}updateColumns []string, whitelist ...string) error { if o == nil { return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for upsert") } @@ -40,7 +39,7 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName // Build cache key in-line uglily - mysql vs postgres problems buf := strmangle.GetBuffer() - {{if ne .DriverName "mysql" -}} + {{if eq .DriverName "postgres"}} if updateOnConflict { buf.WriteByte('t') } else { @@ -73,36 +72,64 @@ func (o *{{$tableNameSingular}}) Upsert(exec boil.Executor, {{if ne .DriverName var err error if !cached { - var ret []string - whitelist, ret = strmangle.InsertColumnSet( + insert, ret := strmangle.InsertColumnSet( {{$varNameSingular}}Columns, {{$varNameSingular}}ColumnsWithDefault, {{$varNameSingular}}ColumnsWithoutDefault, nzDefaults, whitelist, ) + {{- if eq .DriverName "mssql"}} + insert = strmangle.SetComplement(insert, {{$varNameSingular}}ColumnsWithAuto) + for i, v := range insert { + if strmangle.ContainsAny({{$varNameSingular}}PrimaryKeyColumns, v) && strmangle.ContainsAny({{$varNameSingular}}ColumnsWithDefault, v) { + insert = append(insert[:i], insert[i+1:]...) + } + } + if len(insert) == 0 { + return errors.New("models: unable to upsert {{.Table.Name}}, could not build insert column list") + } + + ret = strmangle.SetMerge(ret, {{$varNameSingular}}ColumnsWithAuto) + ret = strmangle.SetMerge(ret, {{$varNameSingular}}ColumnsWithDefault) + + {{end -}} update := strmangle.UpdateColumnSet( {{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, updateColumns, ) + + {{- if eq .DriverName "mssql"}} + update = strmangle.SetComplement(update, {{$varNameSingular}}ColumnsWithAuto) + {{end -}} + if len(update) == 0 { return errors.New("{{.PkgName}}: unable to upsert {{.Table.Name}}, could not build update column list") } - {{if ne .DriverName "mysql" -}} + {{if eq .DriverName "postgres"}} conflict := conflictColumns if len(conflict) == 0 { conflict = make([]string, len({{$varNameSingular}}PrimaryKeyColumns)) copy(conflict, {{$varNameSingular}}PrimaryKeyColumns) } - cache.query = queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, whitelist) - {{- else -}} - cache.query = queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, whitelist) + cache.query = queries.BuildUpsertQueryPostgres(dialect, "{{$schemaTable}}", updateOnConflict, ret, update, conflict, insert) + {{- end -}} + + {{if eq .DriverName "mysql"}} + cache.query = queries.BuildUpsertQueryMySQL(dialect, "{{.Table.Name}}", update, insert) cache.retQuery = fmt.Sprintf( "SELECT %s FROM {{.LQ}}{{.Table.Name}}{{.RQ}} WHERE {{whereClause .LQ .RQ 0 .Table.PKey.Columns}}", strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, ret), ","), ) + {{- end -}} + + {{if eq .DriverName "mssql"}} + cache.query = queries.BuildUpsertQueryMSSQL(dialect, "{{.Table.Name}}", {{$varNameSingular}}PrimaryKeyColumns, update, insert, ret) + + whitelist = append({{$varNameSingular}}PrimaryKeyColumns, update...) + whitelist = append(whitelist, insert...) {{- end}} cache.valueMapping, err = queries.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, whitelist) @@ -207,5 +234,4 @@ CacheNoHooks: {{- else -}} return nil {{- end}} -} -{{- end}} +} \ No newline at end of file diff --git a/templates_test/hooks.tpl b/templates_test/hooks.tpl index 412ab0c..2cdca44 100644 --- a/templates_test/hooks.tpl +++ b/templates_test/hooks.tpl @@ -38,7 +38,6 @@ func {{$varNameSingular}}AfterDeleteHook(e boil.Executor, o *{{$tableNameSingula return nil } -{{if ne .DriverName "mssql" -}} func {{$varNameSingular}}BeforeUpsertHook(e boil.Executor, o *{{$tableNameSingular}}) error { *o = {{$tableNameSingular}}{} return nil @@ -48,7 +47,6 @@ func {{$varNameSingular}}AfterUpsertHook(e boil.Executor, o *{{$tableNameSingula *o = {{$tableNameSingular}}{} return nil } -{{- end}} func test{{$tableNamePlural}}Hooks(t *testing.T) { t.Parallel() @@ -126,7 +124,6 @@ func test{{$tableNamePlural}}Hooks(t *testing.T) { } {{$varNameSingular}}AfterDeleteHooks = []{{$tableNameSingular}}Hook{} - {{if ne .DriverName "mssql" -}} Add{{$tableNameSingular}}Hook(boil.BeforeUpsertHook, {{$varNameSingular}}BeforeUpsertHook) if err = o.doBeforeUpsertHooks(nil); err != nil { t.Errorf("Unable to execute doBeforeUpsertHooks: %s", err) @@ -144,6 +141,5 @@ func test{{$tableNamePlural}}Hooks(t *testing.T) { t.Errorf("Expected AfterUpsertHook function to empty object, but got: %#v", o) } {{$varNameSingular}}AfterUpsertHooks = []{{$tableNameSingular}}Hook{} - {{- end}} } {{- end}} diff --git a/templates_test/singleton/boil_suites_test.tpl b/templates_test/singleton/boil_suites_test.tpl index 022d7be..a2ca510 100644 --- a/templates_test/singleton/boil_suites_test.tpl +++ b/templates_test/singleton/boil_suites_test.tpl @@ -327,7 +327,7 @@ func TestSliceUpdateAll(t *testing.T) { {{end -}} {{- end -}} } -{{if ne .DriverName "mssql" -}} + func TestUpsert(t *testing.T) { {{- range $index, $table := .Tables}} {{- if $table.IsJoinTable -}} @@ -336,5 +336,4 @@ func TestUpsert(t *testing.T) { t.Run("{{$tableName}}", test{{$tableName}}Upsert) {{end -}} {{- end -}} -} -{{- end -}} \ No newline at end of file +} \ No newline at end of file diff --git a/templates_test/upsert.tpl b/templates_test/upsert.tpl index 25807cf..7e08819 100644 --- a/templates_test/upsert.tpl +++ b/templates_test/upsert.tpl @@ -1,4 +1,3 @@ -{{- if ne .DriverName "mssql" -}} {{- $tableNameSingular := .Table.Name | singular | titleCase -}} {{- $tableNamePlural := .Table.Name | plural | titleCase -}} {{- $varNamePlural := .Table.Name | plural | camelCase -}} @@ -48,5 +47,4 @@ func test{{$tableNamePlural}}Upsert(t *testing.T) { if count != 1 { t.Error("want one record, got:", count) } -} -{{- end}} +} \ No newline at end of file