From e3f319346f1ee688dd5f02db62364aac7c14fc0e Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Thu, 11 Aug 2016 18:23:47 +1000 Subject: [PATCH] Finish UpdateAll query builder * Add modifiers to delete builder * Update golden file tests * Add startAt to whereClause --- boil/_fixtures/09.sql | 2 +- boil/_fixtures/10.sql | 1 + boil/query_builders.go | 119 +++++++++++++++++++++------ boil/query_builders_test.go | 16 +++- strmangle/strmangle.go | 10 ++- templates/09_update.tpl | 74 ++++++++++++++++- templates/singleton/boil_helpers.tpl | 3 - templates/singleton/boil_types.tpl | 3 + templates_test/insert.tpl | 5 +- templates_test/update.tpl | 47 +++++++++++ 10 files changed, 240 insertions(+), 40 deletions(-) create mode 100644 boil/_fixtures/10.sql diff --git a/boil/_fixtures/09.sql b/boil/_fixtures/09.sql index 68555a3..7cf3a6f 100644 --- a/boil/_fixtures/09.sql +++ b/boil/_fixtures/09.sql @@ -1 +1 @@ -DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3); \ No newline at end of file +DELETE FROM thing happy, upset as "sad", "fun", thing as stuff, "angry" as mad WHERE ((id=$1 and thing=$2) or stuff=$3) LIMIT 5; \ No newline at end of file diff --git a/boil/_fixtures/10.sql b/boil/_fixtures/10.sql new file mode 100644 index 0000000..4b64646 --- /dev/null +++ b/boil/_fixtures/10.sql @@ -0,0 +1 @@ +UPDATE thing happy, "fun", "stuff" SET ("col1", "col2", "fun"."col3") VALUES ($1, $2, $3) WHERE (aa=$4 or bb=$5) OR (cc=$6) LIMIT 5; \ No newline at end of file diff --git a/boil/query_builders.go b/boil/query_builders.go index 8f2c09c..7eff299 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -70,9 +70,65 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { fmt.Fprintf(buf, " INNER JOIN %s", j.clause) } - where, args := whereClause(q) + where, args := whereClause(q, 1) buf.WriteString(where) + writeModifiers(q, buf) + + buf.WriteByte(';') + return buf, args +} + +func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { + buf := &bytes.Buffer{} + + buf.WriteString("DELETE FROM ") + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + + where, args := whereClause(q, 1) + buf.WriteString(where) + + writeModifiers(q, buf) + + buf.WriteByte(';') + + return buf, args +} + +func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { + buf := &bytes.Buffer{} + + buf.WriteString("UPDATE ") + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) + + cols := make([]string, len(q.update)) + args := make([]interface{}, len(q.update)) + + count := 0 + for name, value := range q.update { + cols[count] = strmangle.IdentQuote(name) + args[count] = value + count++ + } + + buf.WriteString(fmt.Sprintf( + " SET (%s) VALUES (%s)", + strings.Join(cols, ", "), + strmangle.Placeholders(len(cols), 1, 1)), + ) + + where, whereArgs := whereClause(q, len(args)+1) + buf.WriteString(where) + args = append(args, whereArgs...) + + writeModifiers(q, buf) + + buf.WriteByte(';') + + return buf, args +} + +func writeModifiers(q *Query, buf *bytes.Buffer) { if len(q.groupBy) != 0 { fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", ")) } @@ -92,9 +148,6 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { if q.offset != 0 { fmt.Fprintf(buf, " OFFSET %d", q.offset) } - - buf.WriteByte(';') - return buf, args } func writeStars(q *Query) []string { @@ -144,28 +197,12 @@ func writeAsStatements(q *Query) []string { return cols } -func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { - buf := &bytes.Buffer{} - - buf.WriteString("DELETE FROM ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")) - - where, args := whereClause(q) - buf.WriteString(where) - - buf.WriteByte(';') - - return buf, args -} - -func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { - buf := &bytes.Buffer{} - - buf.WriteByte(';') - return buf, nil -} - -func whereClause(q *Query) (string, []interface{}) { +// whereClause parses a where slice and converts it into a +// single WHERE clause like: +// WHERE (a=$1) AND (b=$2) +// +// startAt specifies what number placeholders start at +func whereClause(q *Query, startAt int) (string, []interface{}) { if len(q.where) == 0 { return "", nil } @@ -211,7 +248,35 @@ func whereClause(q *Query) (string, []interface{}) { paramIndex++ } - return paramBuf.String(), args + return convertQuestionMarks(buf.String(), startAt), args +} + +func convertQuestionMarks(clause string, startAt int) string { + if startAt == 0 { + panic("Not a valid start number.") + } + + paramBuf := &bytes.Buffer{} + paramIndex := 0 + + for ; ; startAt++ { + if paramIndex >= len(clause) { + break + } + + clause = clause[paramIndex:] + paramIndex = strings.IndexByte(clause, '?') + + if paramIndex == -1 { + paramBuf.WriteString(clause) + break + } + + paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt)) + paramIndex++ + } + + return paramBuf.String() } // identifierMapping creates a map of all identifiers to potential model names diff --git a/boil/query_builders_test.go b/boil/query_builders_test.go index 160f9c0..65f50bc 100644 --- a/boil/query_builders_test.go +++ b/boil/query_builders_test.go @@ -59,7 +59,21 @@ func TestBuildQuery(t *testing.T) { where: []where{ where{clause: "(id=? and thing=?) or stuff=?", args: []interface{}{}}, }, + limit: 5, }, nil}, + {&Query{ + from: []string{"thing happy", `"fun"`, `stuff`}, + update: map[string]interface{}{ + "col1": 1, + `"col2"`: 2, + `"fun".col3`: 3, + }, + where: []where{ + where{clause: "aa=? or bb=?", orSeparator: true, args: []interface{}{4, 5}}, + where{clause: "cc=?", args: []interface{}{6}}, + }, + limit: 5, + }, []interface{}{1, 2, 3, 4, 5, 6}}, } for i, test := range tests { @@ -297,7 +311,7 @@ func TestWhereClause(t *testing.T) { } for i, test := range tests { - result, _ := whereClause(&test.q) + result, _ := whereClause(&test.q, 1) if result != test.expect { t.Errorf("%d) Mismatch between expect and result:\n%s\n%s\n", i, test.expect, result) } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index f20e7ae..076d1ca 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -208,20 +208,24 @@ func PrefixStringSlice(str string, strs []string) []string { } // Placeholders generates the SQL statement placeholders for in queries. -// For example, ($1,$2,$3),($4,$5,$6) etc. +// For example, ($1, $2, $3), ($4, $5, $6) etc. // It will start counting placeholders at "start". func Placeholders(count int, start int, group int) string { var buf bytes.Buffer + if start == 0 || group == 0 { + panic("Invalid start or group numbers supplied.") + } + if group > 1 { buf.WriteByte('(') } for i := 0; i < count; i++ { if i != 0 { if group > 1 && i%group == 0 { - buf.WriteString(`),(`) + buf.WriteString("), (") } else { - buf.WriteByte(',') + buf.WriteString(", ") } } buf.WriteString(fmt.Sprintf("$%d", start+i)) diff --git a/templates/09_update.tpl b/templates/09_update.tpl index 2a9c768..2aa8ec5 100644 --- a/templates/09_update.tpl +++ b/templates/09_update.tpl @@ -68,7 +68,14 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string return nil } -// UpdateAll updates all rows with matching column names. +// UpdateAllP updates all rows with matching column names, and panics on error. +func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { + if err := q.UpdateAll(cols); err != nil { + panic(boil.WrapErr(err)) + } +} + +// UpdateAll updates all rows with the specified column values. func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { boil.SetUpdate(q.Query, cols) @@ -80,13 +87,72 @@ func (q {{$varNameSingular}}Query) UpdateAll(cols M) error { return nil } -// UpdateAllP updates all rows with matching column names, and panics on error. -func (q {{$varNameSingular}}Query) UpdateAllP(cols M) { - if err := q.UpdateAll(cols); err != nil { +// UpdateAllG updates all rows with the specified column values. +func (o {{$tableNameSingular}}Slice) UpdateAllG(cols M) error { + return o.UpdateAll(boil.GetDB(), cols) +} + +// UpdateAllGP updates all rows with the specified column values, and panics on error. +func (o {{$tableNameSingular}}Slice) UpdateAllGP(cols M) { + if err := o.UpdateAll(boil.GetDB(), cols); err != nil { panic(boil.WrapErr(err)) } } +// UpdateAllP updates all rows with the specified column values, and panics on error. +func (o {{$tableNameSingular}}Slice) UpdateAllP(exec boil.Executor, cols M) { + if err := o.UpdateAll(exec, cols); err != nil { + panic(boil.WrapErr(err)) + } +} + +// UpdateAll updates all rows with the specified column values, using an executor. +func (o {{$tableNameSingular}}Slice) UpdateAll(exec boil.Executor, cols M) error { + if o == nil { + return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for update all") + } + + if len(o) == 0 { + return nil + } + + colNames := make([]string, len(cols)) + var args []interface{} + + count := 0 + for name, value := range cols { + colNames[count] = name + args = append(args, value) + count++ + } + + // Append all of the primary key values for each column + args = append(args, o.inPrimaryKeyArgs()) + + sql := fmt.Sprintf( + `UPDATE {{.Table.Name}} SET (%s) VALUES (%s) WHERE (%s) IN (%s)`, + strings.Join(colNames, ", "), + strmangle.Placeholders(len(args), 1, 1), + strings.Join({{$varNameSingular}}PrimaryKeyColumns, ","), + strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), len(args)+1, len({{$varNameSingular}}PrimaryKeyColumns)), + ) + + q := boil.SQL(sql, args...) + boil.SetExecutor(q, exec) + + _, err := boil.ExecQuery(q) + if err != nil { + return fmt.Errorf("{{.PkgName}}: unable to update all in {{$varNameSingular}} slice: %s", err) + } + + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, sql) + fmt.Fprintln(boil.DebugWriter, ) + } + + return nil +} + // generateUpdateColumns generates the whitelist columns for an update statement // if a whitelist is supplied, it's returned // if a whitelist is missing then we begin with all columns diff --git a/templates/singleton/boil_helpers.tpl b/templates/singleton/boil_helpers.tpl index 790c238..6cb607d 100644 --- a/templates/singleton/boil_helpers.tpl +++ b/templates/singleton/boil_helpers.tpl @@ -1,6 +1,3 @@ -// M type is for providing where filters to Where helpers. -type M map[string]interface{} - // NewQueryG initializes a new Query using the passed in QueryMods func NewQueryG(mods ...qm.QueryMod) *boil.Query { return NewQuery(boil.GetDB(), mods...) diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index cb83863..c5b1075 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -1,3 +1,6 @@ +// M type is for providing columns and column values to UpdateAll. +type M map[string]interface{} + type upsertData struct { conflict []string update []string diff --git a/templates_test/insert.tpl b/templates_test/insert.tpl index fb26ff2..19fcde6 100644 --- a/templates_test/insert.tpl +++ b/templates_test/insert.tpl @@ -28,8 +28,11 @@ func Test{{$tableNamePlural}}Insert(t *testing.T) { j := make({{$tableNameSingular}}Slice, 3) // Perform all Find queries and assign result objects to slice for comparison - for i := 0; i < len(j); i++ { + for i := 0; i < len(o); i++ { j[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}}) + if err != nil { + t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err) + } {{$varNameSingular}}CompareVals(o[i], j[i], t) } diff --git a/templates_test/update.tpl b/templates_test/update.tpl index 7563797..4baac2b 100644 --- a/templates_test/update.tpl +++ b/templates_test/update.tpl @@ -41,3 +41,50 @@ func Test{{$tableNamePlural}}Update(t *testing.T) { {{$varNamePlural}}DeleteAllRows(t) } + +func Test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { + var err error + + // insert random columns to test UpdateAll + o := make({{$tableNameSingular}}Slice, 3) + j := make({{$tableNameSingular}}Slice, 3) + + if err = boil.RandomizeSlice(&o, {{$varNameSingular}}DBTypes, false); err != nil { + t.Errorf("Unable to randomize {{$tableNameSingular}} slice: %s", err) + } + + for i := 0; i < len(o); i++ { + if err = o[i].InsertG(); err != nil { + t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", o[i], err) + } + } + + vals := M{} + + tmp := {{$tableNameSingular}}{} + if err = boil.RandomizeStruct(&tmp, {{$varNameSingular}}DBTypes, false, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Errorf("Unable to randomize struct {{$tableNameSingular}}: %s", err) + } + + // Build the columns and column values from the randomized struct + tmpVal := reflect.Indirect(reflect.ValueOf(tmp)) + nonPrimKeys := boil.SetComplement({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) + for _, col := range nonPrimKeys { + vals[col] = tmpVal.FieldByName(strmangle.TitleCase(col)).Interface() + } + + err = o.UpdateAllG(vals) + if err != nil { + t.Errorf("Failed to update all for {{$tableNameSingular}}: %s", err) + } + + for i := 0; i < len(o); i++ { + j[i], err = {{$tableNameSingular}}FindG({{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o[i]." | join ", "}}) + if err != nil { + t.Errorf("Unable to find {{$tableNameSingular}} row: %s", err) + } + {{$varNameSingular}}CompareVals(o[i], &tmp, t) + } + + {{$varNamePlural}}DeleteAllRows(t) +}