Clean up helpers, remove duplicate funcs

* Refactor DeleteAll for slice
This commit is contained in:
Patrick O'brien 2016-08-09 15:57:54 +10:00
parent 8e3c1d41da
commit 2ece7d14f6
13 changed files with 48 additions and 338 deletions

View file

@ -1,12 +1,8 @@
package boil
import (
"bytes"
"fmt"
"reflect"
"sort"
"strings"
"unicode"
"github.com/nullbio/sqlboiler/strmangle"
)
@ -116,150 +112,3 @@ Outer:
return c
}
// WherePrimaryKeyIn generates a "in" string for where queries
// For example: ("col1","col2") IN (($1,$2), ($3,$4))
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
in := &bytes.Buffer{}
if len(keyNames) == 0 {
return ""
}
in.WriteByte('(')
for i := 0; i < len(keyNames); i++ {
in.WriteString(`"` + keyNames[i] + `"`)
if i < len(keyNames)-1 {
in.WriteByte(',')
}
}
in.WriteString(") IN (")
c := 1
for i := 0; i < numRows; i++ {
for y := 0; y < len(keyNames); y++ {
if len(keyNames) > 1 && y == 0 {
in.WriteByte('(')
}
in.WriteString(fmt.Sprintf("$%d", c))
c++
if len(keyNames) > 1 && y == len(keyNames)-1 {
in.WriteByte(')')
}
if i != numRows-1 || y != len(keyNames)-1 {
in.WriteByte(',')
}
}
}
in.WriteByte(')')
return in.String()
}
// SelectNames returns the column names for a select statement
// Eg: "col1", "col2", "col3"
func SelectNames(results interface{}) string {
var names []string
structValue := reflect.Indirect(reflect.ValueOf(results))
structType := structValue.Type()
for i := 0; i < structValue.NumField(); i++ {
field := structType.Field(i)
var name string
if db := field.Tag.Get("db"); len(db) != 0 {
name = db
} else {
name = goVarToSQLName(field.Name)
}
names = append(names, fmt.Sprintf(`"%s"`, name))
}
return strings.Join(names, ", ")
}
// Update returns the column list for an update statement SET clause
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func Update(columns map[string]interface{}) string {
names := make([]string, 0, len(columns))
for c := range columns {
names = append(names, c)
}
sort.Strings(names)
for i, c := range names {
names[i] = fmt.Sprintf(`"%s"=$%d`, c, i+1)
}
return strings.Join(names, ", ")
}
// SetParamNames takes a slice of columns and returns a comma separated
// list of parameter names for a template statement SET clause.
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string {
names := make([]string, 0, len(columns))
counter := 0
for _, c := range columns {
counter++
names = append(names, fmt.Sprintf(`"%s"=$%d`, c, counter))
}
return strings.Join(names, ", ")
}
// WherePrimaryKey returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WherePrimaryKey(start int, pkeys ...string) string {
var output string
for i, c := range pkeys {
output = fmt.Sprintf(`%s"%s"=$%d`, output, c, start)
start++
if i < len(pkeys)-1 {
output = fmt.Sprintf("%s AND ", output)
}
}
return output
}
// goVarToSQLName converts a go variable name to a column name
// example: HelloFriendID to hello_friend_id
func goVarToSQLName(name string) string {
str := &bytes.Buffer{}
isUpper, upperStreak := false, false
for i := 0; i < len(name); i++ {
c := rune(name[i])
if unicode.IsDigit(c) || unicode.IsLower(c) {
isUpper = false
upperStreak = false
str.WriteRune(c)
continue
}
if isUpper {
upperStreak = true
} else if i != 0 {
str.WriteByte('_')
}
isUpper = true
if j := i + 1; j < len(name) && upperStreak && unicode.IsLower(rune(name[j])) {
str.WriteByte('_')
}
str.WriteRune(unicode.ToLower(c))
}
return str.String()
}

View file

@ -217,94 +217,3 @@ func TestSortByKeys(t *testing.T) {
}
}
}
func TestWherePrimaryKeyIn(t *testing.T) {
t.Parallel()
x := WherePrimaryKeyIn(1, "aa")
expect := `("aa") IN ($1)`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(2, "aa")
expect = `("aa") IN ($1,$2)`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(3, "aa")
expect = `("aa") IN ($1,$2,$3)`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(1, "aa", "bb")
expect = `("aa","bb") IN (($1,$2))`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(2, "aa", "bb")
expect = `("aa","bb") IN (($1,$2),($3,$4))`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(3, "aa", "bb")
expect = `("aa","bb") IN (($1,$2),($3,$4),($5,$6))`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(4, "aa", "bb")
expect = `("aa","bb") IN (($1,$2),($3,$4),($5,$6),($7,$8))`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
x = WherePrimaryKeyIn(4, "aa", "bb", "cc")
expect = `("aa","bb","cc") IN (($1,$2,$3),($4,$5,$6),($7,$8,$9),($10,$11,$12))`
if x != expect {
t.Errorf("Expected %s, got %s\n", expect, x)
}
}
func TestGoVarToSQLName(t *testing.T) {
t.Parallel()
tests := []struct {
In, Out string
}{
{"IDStruct", "id_struct"},
{"WigglyBits", "wiggly_bits"},
{"HoboIDFriend3333", "hobo_id_friend3333"},
{"3333friend", "3333friend"},
{"ID3ID", "id3_id"},
{"Wei3rd", "wei3rd"},
{"He3I3Test", "he3_i3_test"},
{"He3ID3Test", "he3_id3_test"},
{"HelloFriendID", "hello_friend_id"},
}
for i, test := range tests {
if out := goVarToSQLName(test.In); out != test.Out {
t.Errorf("%d) from: %q, want: %q, got: %q", i, test.In, test.Out, out)
}
}
}
func TestSelectNames(t *testing.T) {
t.Parallel()
o := testObj{
Name: "bob",
ID: 5,
HeadSize: 23,
}
result := SelectNames(o)
if result != `"id", "TestHello", "head_size"` {
t.Error("Result was wrong, got:", result)
}
}

View file

@ -233,48 +233,32 @@ func Placeholders(count int, start int, group int) string {
return buf.String()
}
// WhereClause is a version of Where that binds multiple checks together
// with an or statement.
// WhereMultiple(1, 2, "a", "b") = "(a=$1 and b=$2) or (a=$3 and b=$4)"
func WhereClause(start, count int, cols []string) string {
if start == 0 {
panic("0 is not a valid start number for whereMultiple")
// SetParamNames takes a slice of columns and returns a comma separated
// list of parameter names for a template statement SET clause.
// eg: "col1"=$1, "col2"=$2, "col3"=$3
func SetParamNames(columns []string) string {
names := make([]string, 0, len(columns))
counter := 0
for _, c := range columns {
counter++
names = append(names, fmt.Sprintf(`"%s"=$%d`, c, counter))
}
buf := &bytes.Buffer{}
for i := 0; i < count; i++ {
if i != 0 {
buf.WriteString(" OR ")
}
buf.WriteByte('(')
for j, key := range cols {
if j != 0 {
buf.WriteString(" AND ")
}
fmt.Fprintf(buf, `"%s"=$%d`, key, start+i*len(cols)+j)
}
buf.WriteByte(')')
}
return buf.String()
return strings.Join(names, ", ")
}
// InClause generates SQL that could go inside an "IN ()" statement
// $1, $2, $3
func InClause(start, count int) string {
// WhereClause returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WhereClause(start int, cols []string) string {
if start == 0 {
panic("0 is not a valid start number for inClause")
panic("0 is not a valid start number for whereClause")
}
buf := &bytes.Buffer{}
for i := 0; i < count; i++ {
if i > 0 {
buf.WriteByte(',')
}
fmt.Fprintf(buf, "$%d", i+start)
ret := make([]string, len(cols))
for i, c := range cols {
ret[i] = fmt.Sprintf(`"%s"=$%d`, c, start+i)
}
return buf.String()
return strings.Join(ret, " AND ")
}
// DriverUsesLastInsertID returns whether the database driver supports the

View file

@ -307,57 +307,21 @@ func TestWhereClause(t *testing.T) {
tests := []struct {
Cols []string
Start int
Count int
Should string
}{
{Cols: []string{"col1", "col2"}, Start: 2, Count: 2, Should: `("col1"=$2 AND "col2"=$3) OR ("col1"=$4 AND "col2"=$5)`},
{Cols: []string{"col1", "col2"}, Start: 4, Count: 2, Should: `("col1"=$4 AND "col2"=$5) OR ("col1"=$6 AND "col2"=$7)`},
{Cols: []string{"col1", "col2", "col3"}, Start: 4, Count: 1, Should: `("col1"=$4 AND "col2"=$5 AND "col3"=$6)`},
{Cols: []string{"col1", "col2"}, Start: 2, Should: `("col1"=$2 AND "col2"=$3 AND "col1"=$4 AND "col2"=$5)`},
{Cols: []string{"col1", "col2"}, Start: 4, Should: `("col1"=$4 AND "col2"=$5 AND "col1"=$6 AND "col2"=$7)`},
{Cols: []string{"col1", "col2", "col3"}, Start: 4, Should: `("col1"=$4 AND "col2"=$5 AND "col3"=$6)`},
}
for i, test := range tests {
r := WhereClause(test.Start, test.Count, test.Cols)
r := WhereClause(test.Start, test.Cols)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s", i, test.Should, r)
}
}
}
func TestWhereMultiplePanic(t *testing.T) {
t.Parallel()
defer func() {
if recover() == nil {
t.Error("did not panic")
}
}()
WhereClause(0, 0, nil)
}
func TestInClause(t *testing.T) {
t.Parallel()
if str := InClause(1, 2); str != `$1,$2` {
t.Error("wrong output:", str)
}
if str := InClause(2, 2); str != `$2,$3` {
t.Error("wrong output:", str)
}
}
func TestInClausePanic(t *testing.T) {
t.Parallel()
defer func() {
if recover() == nil {
t.Error("did not panic")
}
}()
InClause(0, 0)
}
func TestSubstring(t *testing.T) {
t.Parallel()

View file

@ -29,7 +29,7 @@ func {{$tableNameSingular}}Find(exec boil.Executor, {{$pkArgs}}, selectCols ...s
sel = strings.Join(strmangle.IdentQuoteSlice(selectCols), ",")
}
sql := fmt.Sprintf(
`select %s from "{{.Table.Name}}" where {{whereClause 1 1 .Table.PKey.Columns}}`, sel,
`select %s from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}}`, sel,
)
q := boil.SQL(sql, {{$pkNames | join ", "}})
boil.SetExecutor(q, exec)

View file

@ -49,7 +49,7 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string
lastId, err := result.lastInsertId()
if err != nil || lastId == 0 {
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, 1, wl))
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, `","`), strmangle.WhereClause(1, wl))
rows, err := exec.Query(sel, boil.GetStructValues(o, wl...)...)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)

View file

@ -44,7 +44,7 @@ func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string
wl := o.generateUpdateColumns(whitelist...)
if len(wl) != 0 {
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(wl), boil.WherePrimaryKey(len(wl)+1, {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}}))
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, strmangle.SetParamNames(wl), strmangle.WhereClause(len(wl)+1, {{$varNameSingular}}PrimaryKeyColumns))
values = boil.GetStructValues(o, wl...)
values = append(values, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})
_, err = exec.Exec(query, values...)

View file

@ -30,7 +30,7 @@ func (o *{{$tableNameSingular}}) Delete(exec boil.Executor) error {
mods = append(mods,
qm.From("{{.Table.Name}}"),
qm.Where(`{{whereClause 1 1 .Table.PKey.Columns}}`, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}),
qm.Where(`{{whereClause 1 .Table.PKey.Columns}}`, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}}),
)
query := NewQuery(exec, mods...)
@ -97,24 +97,28 @@ func (o {{$tableNameSingular}}Slice) DeleteAll(exec boil.Executor) error {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for delete all")
}
var mods []qm.QueryMod
if len(o) == 0 {
return nil
}
args := o.inPrimaryKeyArgs()
in := boil.WherePrimaryKeyIn(len(o), {{.Table.PKey.Columns | stringMap .StringFuncs.quoteWrap | join ", "}})
mods = append(mods,
qm.From("{{.Table.Name}}"),
qm.Where(in, args...),
sql := fmt.Sprintf(
`DELETE FROM {{.Table.Name}} WHERE (%s) IN (%s)`,
strings.Join({{$varNameSingular}}PrimaryKeyColumns, ","),
strmangle.Placeholders(len(o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)
query := NewQuery(exec, mods...)
boil.SetDelete(query)
q := boil.SQL(sql, args...)
boil.SetExecutor(q, exec)
_, err := boil.ExecQuery(query)
_, err := boil.ExecQuery(q)
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to delete all from {{$varNameSingular}} slice: %s", err)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, sql)
fmt.Fprintln(boil.DebugWriter, args)
}

View file

@ -59,19 +59,19 @@ func (o *{{$tableNameSingular}}Slice) ReloadAllG() error {
// ReloadAll refetches every row with matching primary key column values
// and overwrites the original object slice with the newly updated slice.
func (o *{{$tableNameSingular}}Slice) ReloadAll(exec boil.Executor) error {
if o == nil {
return errors.New("{{.PkgName}}: no {{$tableNameSingular}} slice provided for reload all")
}
if len(*o) == 0 {
return nil
}
{{$varNamePlural}} := {{$tableNameSingular}}Slice{}
var args []interface{}
for i := 0; i < len(*o); i++ {
args = append(args, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "(*o)[i]." | join ", "}})
}
args := o.inPrimaryKeyArgs()
sql := fmt.Sprintf(
`select {{.Table.Name}}.* from {{.Table.Name}} where (%s) in (%s)`,
`SELECT {{.Table.Name}}.* FROM {{.Table.Name}} WHERE (%s) IN (%s)`,
strings.Join({{$varNameSingular}}PrimaryKeyColumns, ","),
strmangle.Placeholders(len(*o) * len({{$varNameSingular}}PrimaryKeyColumns), 1, len({{$varNameSingular}}PrimaryKeyColumns)),
)

View file

@ -7,7 +7,7 @@ func {{$tableNameSingular}}Exists(exec boil.Executor, {{$pkArgs}}) (bool, error)
var exists bool
row := exec.QueryRow(
`select exists(select 1 from "{{.Table.Name}}" where {{whereClause 1 1 .Table.PKey.Columns}} limit 1)`,
`select exists(select 1 from "{{.Table.Name}}" where {{whereClause 1 .Table.PKey.Columns}} limit 1)`,
{{$pkNames | join ", "}},
)

View file

@ -24,7 +24,7 @@ func Test{{$tableNamePlural}}Exists(t *testing.T) {
t.Errorf("Expected {{$tableNameSingular}}ExistsG to return true, but got false.")
}
whereClause := strmangle.WhereClause(1, 1, {{$varNameSingular}}PrimaryKeyColumns)
whereClause := strmangle.WhereClause(1, {{$varNameSingular}}PrimaryKeyColumns)
e, err = {{$tableNamePlural}}G(qm.Where(whereClause, boil.GetStructValues(o, {{$varNameSingular}}PrimaryKeyColumns...)...)).Exists()
if err != nil {
t.Errorf("Unable to check if {{$tableNameSingular}} exists: %s", err)

View file

@ -16,7 +16,7 @@ func Test{{$tableNamePlural}}Bind(t *testing.T) {
j := {{$tableNameSingular}}{}
err = {{$tableNamePlural}}G(qm.Where(`{{whereClause 1 1 .Table.PKey.Columns}}`, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})).Bind(&j)
err = {{$tableNamePlural}}G(qm.Where(`{{whereClause 1 .Table.PKey.Columns}}`, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "o." | join ", "}})).Bind(&j)
if err != nil {
t.Errorf("Unable to call Bind on {{$tableNameSingular}} single object: %s", err)
}

View file

@ -29,7 +29,7 @@ func Test{{$tableNamePlural}}Select(t *testing.T) {
t.Errorf("Unable to insert item {{$tableNameSingular}}:\n%#v\nErr: %s", item, err)
}
err = {{$tableNamePlural}}G(qm.Select({{$varNameSingular}}AutoIncrementColumns...), qm.Where(`{{whereClause 1 1 .Table.PKey.Columns}}`, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}})).Bind(x)
err = {{$tableNamePlural}}G(qm.Select({{$varNameSingular}}AutoIncrementColumns...), qm.Where(`{{whereClause 1 .Table.PKey.Columns}}`, {{.Table.PKey.Columns | stringMap .StringFuncs.titleCase | prefixStringSlice "item." | join ", "}})).Bind(x)
if err != nil {
t.Errorf("Unable to select insert results with bind: %s", err)
}