diff --git a/boil/helpers.go b/boil/helpers.go new file mode 100644 index 0000000..0b13b3c --- /dev/null +++ b/boil/helpers.go @@ -0,0 +1,104 @@ +package boil + +import ( + "bytes" + "fmt" + "reflect" + "sort" + "strings" + "unicode" +) + +// SelectNames returns the column names for a select statement +func SelectNames(results interface{}) string { + var names []string + + structValue := reflect.ValueOf(results) + if structValue.Kind() == reflect.Ptr { + structValue = structValue.Elem() + } + + structType := structValue.Type() + for i := 0; i < structValue.NumField(); i++ { + field := structType.Field(i) + var name string + + if db := field.Tag.Get("db"); len(db) != 0 { + name = db + } else { + name = goVarToSQLName(field.Name) + } + + names = append(names, name) + } + + return strings.Join(names, ", ") +} + +// Where returns the where clause for an sql statement +func Where(columns map[string]interface{}) string { + names := make([]string, 0, len(columns)) + + for c := range columns { + names = append(names, c) + } + + sort.Strings(names) + + for i, c := range names { + names[i] = fmt.Sprintf("%s=$%d", c, i+1) + } + + return strings.Join(names, " AND ") +} + +// WhereParams returns a list of sql parameter values for the query +func WhereParams(columns map[string]interface{}) []interface{} { + names := make([]string, 0, len(columns)) + results := make([]interface{}, 0, len(columns)) + + for c := range columns { + names = append(names, c) + } + + sort.Strings(names) + + for _, c := range names { + results = append(results, columns[c]) + } + + return results +} + +// goVarToSQLName converts a go variable name to a column name +// example: HelloFriendID to hello_friend_id +func goVarToSQLName(name string) string { + str := &bytes.Buffer{} + isUpper, upperStreak := false, false + + for i := 0; i < len(name); i++ { + c := rune(name[i]) + if unicode.IsDigit(c) || unicode.IsLower(c) { + isUpper = false + upperStreak = false + + str.WriteRune(c) + continue + } + + if isUpper { + upperStreak = true + } else if i != 0 { + str.WriteByte('_') + } + isUpper = true + + if j := i + 1; j < len(name) && upperStreak && unicode.IsLower(rune(name[j])) { + str.WriteByte('_') + } + + str.WriteRune(unicode.ToLower(c)) + } + + return str.String() +} diff --git a/boil/helpers_test.go b/boil/helpers_test.go new file mode 100644 index 0000000..fe72562 --- /dev/null +++ b/boil/helpers_test.go @@ -0,0 +1,86 @@ +package boil + +import ( + "testing" + "time" +) + +type testObj struct { + ID int + Name string `db:"TestHello"` + HeadSize int +} + +func TestGoVarToSQLName(t *testing.T) { + t.Parallel() + + tests := []struct { + In, Out string + }{ + {"IDStruct", "id_struct"}, + {"WigglyBits", "wiggly_bits"}, + {"HoboIDFriend3333", "hobo_id_friend3333"}, + {"3333friend", "3333friend"}, + {"ID3ID", "id3_id"}, + {"Wei3rd", "wei3rd"}, + {"He3I3Test", "he3_i3_test"}, + {"He3ID3Test", "he3_id3_test"}, + {"HelloFriendID", "hello_friend_id"}, + } + + for i, test := range tests { + if out := goVarToSQLName(test.In); out != test.Out { + t.Errorf("%d) from: %q, want: %q, got: %q", i, test.In, test.Out, out) + } + } +} + +func TestSelectNames(t *testing.T) { + t.Parallel() + + o := testObj{ + Name: "bob", + ID: 5, + HeadSize: 23, + } + + result := SelectNames(o) + if result != `id, TestHello, head_size` { + t.Error("Result was wrong, got:", result) + } +} + +func TestWhere(t *testing.T) { + t.Parallel() + + columns := map[string]interface{}{ + "name": "bob", + "id": 5, + "date": time.Now(), + } + + result := Where(columns) + + if result != `date=$1 AND id=$2 AND name=$3` { + t.Error("Result was wrong, got:", result) + } +} + +func TestWhereParams(t *testing.T) { + t.Parallel() + + columns := map[string]interface{}{ + "name": "bob", + "id": 5, + } + + result := WhereParams(columns) + + if result[0].(int) != 5 { + t.Error("Result[0] was wrong, got:", result[0]) + } + + if result[1].(string) != "bob" { + t.Error("Result[1] was wrong, got:", result[1]) + } +} diff --git a/boil/types.go b/boil/types.go index 447e69c..23109c2 100644 --- a/boil/types.go +++ b/boil/types.go @@ -7,6 +7,7 @@ type DB interface { Exec(query string, args ...interface{}) (sql.Result, error) Query(query string, args ...interface{}) (*sql.Rows, error) QueryRow(query string, args ...interface{}) *sql.Row + Select(dest interface{}, query string, args ...interface{}) error } // M type is for providing where filters to Where helpers. diff --git a/cmds/templates/all.tpl b/cmds/templates/all.tpl index d5ec46b..97371ca 100644 --- a/cmds/templates/all.tpl +++ b/cmds/templates/all.tpl @@ -4,7 +4,7 @@ func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) { var {{$varName}} []*{{$tableName}} - rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}}`) + rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}} FROM {{.Table}}`) if err != nil { return nil, fmt.Errorf("models: failed to query: %v", err) } diff --git a/cmds/templates/allby.tpl b/cmds/templates/allby.tpl deleted file mode 100644 index 6a22ccb..0000000 --- a/cmds/templates/allby.tpl +++ /dev/null @@ -1,13 +0,0 @@ -{{- $tableName := .Table -}} -// {{titleCase $tableName}}AllBy retrieves all records with the specified column values. -func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) { - {{$varName := camelCase $tableName -}} - var {{$varName}} []*{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`) - - if err != nil { - return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) - } - - return {{$varName}}, nil -} diff --git a/cmds/templates/fieldsall.tpl b/cmds/templates/fieldsall.tpl deleted file mode 100644 index 5796b16..0000000 --- a/cmds/templates/fieldsall.tpl +++ /dev/null @@ -1,15 +0,0 @@ -{{- $tableName := .Table -}} -// {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records. -// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve. -// For example: friendName string `db:"friend_name"` -func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error { - {{$varName := camelCase $tableName -}} - var {{$varName}} []*{{titleCase $tableName}} - err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`) - - if err != nil { - return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err) - } - - return {{$varName}}, nil -}