From 912693a12433241bf746ae8b7bc5b25f3406d2e7 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Mon, 12 Sep 2016 23:28:23 -0700 Subject: [PATCH] Update parameter generation for mysql --- strmangle/strmangle.go | 22 ++++++++++++------- strmangle/strmangle_test.go | 23 ++++++++++++++++++++ templates/04_relationship_to_one.tpl | 2 +- templates/05_relationship_to_many.tpl | 4 ++-- templates/09_relationship_to_many_setops.tpl | 8 +++---- templates/11_find.tpl | 2 +- templates/12_insert.tpl | 2 +- templates/13_update.tpl | 5 ++++- templates/15_delete.tpl | 2 +- templates/17_exists.tpl | 2 +- templates_test/relationship_to_many.tpl | 4 ++-- templates_test/upsert.tpl | 3 +++ 12 files changed, 57 insertions(+), 22 deletions(-) diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index ab0147f..5f62aca 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -427,14 +427,19 @@ func Placeholders(indexPlaceholders bool, count int, start int, group int) strin // SetParamNames takes a slice of columns and returns a comma separated // list of parameter names for a template statement SET clause. // eg: "col1"=$1, "col2"=$2, "col3"=$3 -func SetParamNames(columns []string) string { +func SetParamNames(lq, rq string, start int, columns []string) string { buf := GetBuffer() defer PutBuffer(buf) for i, c := range columns { - buf.WriteString(fmt.Sprintf(`"%s"=$%d`, c, i+1)) + if start != 0 { + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, i+start)) + } else { + buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq)) + } + if i < len(columns)-1 { - buf.WriteString(", ") + buf.WriteByte(',') } } @@ -444,15 +449,16 @@ func SetParamNames(columns []string) string { // WhereClause returns the where clause using start as the $ flag index // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" func WhereClause(lq, rq string, start int, cols []string) string { - if start == 0 { - panic("0 is not a valid start number for whereClause") - } - buf := GetBuffer() defer PutBuffer(buf) for i, c := range cols { - buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i)) + if start != 0 { + buf.WriteString(fmt.Sprintf(`%s%s%s=$%d`, lq, c, rq, start+i)) + } else { + buf.WriteString(fmt.Sprintf(`%s%s%s=?`, lq, c, rq)) + } + if i < len(cols)-1 { buf.WriteString(" AND ") } diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index f44ebfa..6d802d4 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -317,6 +317,28 @@ func TestPrefixStringSlice(t *testing.T) { } } +func TestSetParamNames(t *testing.T) { + t.Parallel() + + tests := []struct { + Cols []string + Start int + Should string + }{ + {Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=?,"col2"=?`}, + {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, + {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4,"col2"=$5`}, + {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4,"col2"=$5,"col3"=$6`}, + } + + for i, test := range tests { + r := SetParamNames(`"`, `"`, test.Start, test.Cols) + if r != test.Should { + t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test) + } + } +} + func TestWhereClause(t *testing.T) { t.Parallel() @@ -325,6 +347,7 @@ func TestWhereClause(t *testing.T) { Start int Should string }{ + {Cols: []string{"col1", "col2"}, Start: 0, Should: `"col1"=? AND "col2"=?`}, {Cols: []string{"col1"}, Start: 2, Should: `"col1"=$2`}, {Cols: []string{"col1", "col2"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5`}, {Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `"col1"=$4 AND "col2"=$5 AND "col3"=$6`}, diff --git a/templates/04_relationship_to_one.tpl b/templates/04_relationship_to_one.tpl index 06dac02..0fdfd87 100644 --- a/templates/04_relationship_to_one.tpl +++ b/templates/04_relationship_to_one.tpl @@ -10,7 +10,7 @@ func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}G(mods . // {{.Function.Name}} pointed to by the foreign key. func ({{.Function.Receiver}} *{{.LocalTable.NameGo}}) {{.Function.Name}}(exec boil.Executor, mods ...qm.QueryMod) ({{$varNameSingular}}Query) { queryMods := []qm.QueryMod{ - qm.Where("{{.ForeignTable.ColumnName}}=$1", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), + qm.Where("{{.ForeignTable.ColumnName}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{.Function.Receiver}}.{{.LocalTable.ColumnNameGo}}), } queryMods = append(queryMods, mods...) diff --git a/templates/05_relationship_to_many.tpl b/templates/05_relationship_to_many.tpl index 9d81847..b965377 100644 --- a/templates/05_relationship_to_many.tpl +++ b/templates/05_relationship_to_many.tpl @@ -33,11 +33,11 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) {{$rel.Function.Na {{if .ToJoinTable -}} queryMods = append(queryMods, qm.InnerJoin("{{.JoinTable | $dot.SchemaTable}} as {{id 1 | $dot.Quotes}} on {{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}} = {{id 1 | $dot.Quotes}}.{{.JoinForeignColumn | $dot.Quotes}}"), - qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}=$1", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + qm.Where("{{id 1 | $dot.Quotes}}.{{.JoinLocalColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{else -}} queryMods = append(queryMods, - qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}=$1", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), + qm.Where("{{id 0 | $dot.Quotes}}.{{.ForeignColumn | $dot.Quotes}}={{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}", {{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}), ) {{end}} diff --git a/templates/09_relationship_to_many_setops.tpl b/templates/09_relationship_to_many_setops.tpl index 59842f6..0280c31 100644 --- a/templates/09_relationship_to_many_setops.tpl +++ b/templates/09_relationship_to_many_setops.tpl @@ -39,7 +39,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function {{if .ToJoinTable -}} for _, rel := range related { - query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)" + query := "insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}, rel.{{$rel.ForeignTable.ColumnNameGo}}} if boil.DebugMode { @@ -96,10 +96,10 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Add{{$rel.Function // Sets related.R.{{$rel.Function.ForeignName}}'s {{$rel.Function.Name}} accordingly. func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Set{{$rel.Function.Name}}(exec boil.Executor, insert bool, related ...*{{$rel.ForeignTable.NameGo}}) error { {{if .ToJoinTable -}} - query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = $1" + query := "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{else -}} - query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = $1" + query := "update {{.ForeignTable | $dot.SchemaTable}} set {{.ForeignColumn | $dot.Quotes}} = null where {{.ForeignColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}}" values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} {{end -}} if boil.DebugMode { @@ -140,7 +140,7 @@ func ({{$rel.Function.Receiver}} *{{$rel.LocalTable.NameGo}}) Remove{{$rel.Funct var err error {{if .ToJoinTable -}} query := fmt.Sprintf( - "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = $1 and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", + "delete from {{.JoinTable | $dot.SchemaTable}} where {{.JoinLocalColumn | $dot.Quotes}} = {{if $dot.Dialect.IndexPlaceholders}}$1{{else}}?{{end}} and {{.JoinForeignColumn | $dot.Quotes}} in (%s)", strmangle.Placeholders(dialect.IndexPlaceholders, len(related), 1, 1), ) values := []interface{}{{"{"}}{{$rel.Function.Receiver}}.{{$rel.LocalTable.ColumnNameGo}}} diff --git a/templates/11_find.tpl b/templates/11_find.tpl index f7ad818..aede965 100644 --- a/templates/11_find.tpl +++ b/templates/11_find.tpl @@ -29,7 +29,7 @@ func Find{{$tableNameSingular}}(exec boil.Executor, {{$pkArgs}}, selectCols ...s sel = strings.Join(strmangle.IdentQuoteSlice(dialect.LQ, dialect.RQ, selectCols), ",") } query := fmt.Sprintf( - "select %s from {{.Table.Name | .SchemaTable}} where {{whereClause .LQ .RQ 1 .Table.PKey.Columns}}", sel, + "select %s from {{.Table.Name | .SchemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}", sel, ) q := boil.SQL(exec, query, {{$pkNames | join ", "}}) diff --git a/templates/12_insert.tpl b/templates/12_insert.tpl index a320871..d054124 100644 --- a/templates/12_insert.tpl +++ b/templates/12_insert.tpl @@ -69,7 +69,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string if len(cache.retMapping) != 0 { {{if .UseLastInsertID -}} - cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", 1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.retQuery = fmt.Sprintf("SELECT %s FROM {{$schemaTable}} WHERE %s", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}"), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns)) {{else -}} cache.query += fmt.Sprintf(" RETURNING {{.LQ}}%s{{.RQ}}", strings.Join(returnColumns, "{{.LQ}},{{.RQ}}")) {{end -}} diff --git a/templates/13_update.tpl b/templates/13_update.tpl index b3ade49..5b858c1 100644 --- a/templates/13_update.tpl +++ b/templates/13_update.tpl @@ -53,7 +53,10 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string if !cached { wl := strmangle.UpdateColumnSet({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns, whitelist) - cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", strmangle.SetParamNames(wl), strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns)) + cache.query = fmt.Sprintf("UPDATE {{$schemaTable}} SET %s WHERE %s", + strmangle.SetParamNames("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}1{{else}}0{{end}}, wl), + strmangle.WhereClause("{{.LQ}}", "{{.RQ}}", {{if .Dialect.IndexPlaceholders}}len(wl)+1{{else}}0{{end}}, {{$varNameSingular}}PrimaryKeyColumns), + ) cache.valueMapping, err = boil.BindMapping({{$varNameSingular}}Type, {{$varNameSingular}}Mapping, append(wl, {{$varNameSingular}}PrimaryKeyColumns...)) if err != nil { return err diff --git a/templates/15_delete.tpl b/templates/15_delete.tpl index 9734369..025c5ba 100644 --- a/templates/15_delete.tpl +++ b/templates/15_delete.tpl @@ -44,7 +44,7 @@ func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error { args := o.inPrimaryKeyArgs() - sql := "DELETE FROM {{$schemaTable}} WHERE {{whereClause .LQ .RQ 1 .Table.PKey.Columns}}" + sql := "DELETE FROM {{$schemaTable}} WHERE {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}}" if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates/17_exists.tpl b/templates/17_exists.tpl index 94f829b..b64da2f 100644 --- a/templates/17_exists.tpl +++ b/templates/17_exists.tpl @@ -7,7 +7,7 @@ func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error) { var exists bool - sql := "select exists(select 1 from {{$schemaTable}} where {{whereClause .LQ .RQ 1 .Table.PKey.Columns}} limit 1)" + sql := "select exists(select 1 from {{$schemaTable}} where {{if .Dialect.IndexPlaceholders}}{{whereClause .LQ .RQ 1 .Table.PKey.Columns}}{{else}}{{whereClause .LQ .RQ 0 .Table.PKey.Columns}}{{end}} limit 1)" if boil.DebugMode { fmt.Fprintln(boil.DebugWriter, sql) diff --git a/templates_test/relationship_to_many.tpl b/templates_test/relationship_to_many.tpl index a157f48..b147890 100644 --- a/templates_test/relationship_to_many.tpl +++ b/templates_test/relationship_to_many.tpl @@ -41,11 +41,11 @@ func test{{$rel.LocalTable.NameGo}}ToMany{{$rel.Function.Name}}(t *testing.T) { } {{if .ToJoinTable -}} - _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, b.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } - _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values ($1, $2)", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) + _, err = tx.Exec("insert into {{.JoinTable | $dot.SchemaTable}} ({{.JoinLocalColumn | $dot.Quotes}}, {{.JoinForeignColumn | $dot.Quotes}}) values {{if $dot.Dialect.IndexPlaceholders}}($1, $2){{else}}(?, ?){{end}}", a.{{$rel.LocalTable.ColumnNameGo}}, c.{{$rel.ForeignTable.ColumnNameGo}}) if err != nil { t.Fatal(err) } diff --git a/templates_test/upsert.tpl b/templates_test/upsert.tpl index cb48b87..00c667b 100644 --- a/templates_test/upsert.tpl +++ b/templates_test/upsert.tpl @@ -3,6 +3,9 @@ {{- $varNamePlural := .Table.Name | plural | camelCase -}} {{- $varNameSingular := .Table.Name | singular | camelCase -}} func test{{$tableNamePlural}}Upsert(t *testing.T) { + {{if not (eq .DriverName "postgres") -}} + t.Skip("not implemented for {{.DriverName}}") + {{end -}} t.Parallel() seed := randomize.NewSeed()