Added helpers, select to DB interface and renames
* Added select and where helpers for the templates - These run at run time * Added select to the boil DB interface * Renamed some of the broken template names and fixed some templates
This commit is contained in:
parent
6228216ff6
commit
0768a89aa6
6 changed files with 192 additions and 29 deletions
104
boil/helpers.go
Normal file
104
boil/helpers.go
Normal file
|
@ -0,0 +1,104 @@
|
|||
package boil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// SelectNames returns the column names for a select statement
|
||||
func SelectNames(results interface{}) string {
|
||||
var names []string
|
||||
|
||||
structValue := reflect.ValueOf(results)
|
||||
if structValue.Kind() == reflect.Ptr {
|
||||
structValue = structValue.Elem()
|
||||
}
|
||||
|
||||
structType := structValue.Type()
|
||||
for i := 0; i < structValue.NumField(); i++ {
|
||||
field := structType.Field(i)
|
||||
var name string
|
||||
|
||||
if db := field.Tag.Get("db"); len(db) != 0 {
|
||||
name = db
|
||||
} else {
|
||||
name = goVarToSQLName(field.Name)
|
||||
}
|
||||
|
||||
names = append(names, name)
|
||||
}
|
||||
|
||||
return strings.Join(names, ", ")
|
||||
}
|
||||
|
||||
// Where returns the where clause for an sql statement
|
||||
func Where(columns map[string]interface{}) string {
|
||||
names := make([]string, 0, len(columns))
|
||||
|
||||
for c := range columns {
|
||||
names = append(names, c)
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
|
||||
for i, c := range names {
|
||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
||||
}
|
||||
|
||||
return strings.Join(names, " AND ")
|
||||
}
|
||||
|
||||
// WhereParams returns a list of sql parameter values for the query
|
||||
func WhereParams(columns map[string]interface{}) []interface{} {
|
||||
names := make([]string, 0, len(columns))
|
||||
results := make([]interface{}, 0, len(columns))
|
||||
|
||||
for c := range columns {
|
||||
names = append(names, c)
|
||||
}
|
||||
|
||||
sort.Strings(names)
|
||||
|
||||
for _, c := range names {
|
||||
results = append(results, columns[c])
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// goVarToSQLName converts a go variable name to a column name
|
||||
// example: HelloFriendID to hello_friend_id
|
||||
func goVarToSQLName(name string) string {
|
||||
str := &bytes.Buffer{}
|
||||
isUpper, upperStreak := false, false
|
||||
|
||||
for i := 0; i < len(name); i++ {
|
||||
c := rune(name[i])
|
||||
if unicode.IsDigit(c) || unicode.IsLower(c) {
|
||||
isUpper = false
|
||||
upperStreak = false
|
||||
|
||||
str.WriteRune(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if isUpper {
|
||||
upperStreak = true
|
||||
} else if i != 0 {
|
||||
str.WriteByte('_')
|
||||
}
|
||||
isUpper = true
|
||||
|
||||
if j := i + 1; j < len(name) && upperStreak && unicode.IsLower(rune(name[j])) {
|
||||
str.WriteByte('_')
|
||||
}
|
||||
|
||||
str.WriteRune(unicode.ToLower(c))
|
||||
}
|
||||
|
||||
return str.String()
|
||||
}
|
86
boil/helpers_test.go
Normal file
86
boil/helpers_test.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package boil
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
type testObj struct {
|
||||
ID int
|
||||
Name string `db:"TestHello"`
|
||||
HeadSize int
|
||||
}
|
||||
|
||||
func TestGoVarToSQLName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
In, Out string
|
||||
}{
|
||||
{"IDStruct", "id_struct"},
|
||||
{"WigglyBits", "wiggly_bits"},
|
||||
{"HoboIDFriend3333", "hobo_id_friend3333"},
|
||||
{"3333friend", "3333friend"},
|
||||
{"ID3ID", "id3_id"},
|
||||
{"Wei3rd", "wei3rd"},
|
||||
{"He3I3Test", "he3_i3_test"},
|
||||
{"He3ID3Test", "he3_id3_test"},
|
||||
{"HelloFriendID", "hello_friend_id"},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
if out := goVarToSQLName(test.In); out != test.Out {
|
||||
t.Errorf("%d) from: %q, want: %q, got: %q", i, test.In, test.Out, out)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSelectNames(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
o := testObj{
|
||||
Name: "bob",
|
||||
ID: 5,
|
||||
HeadSize: 23,
|
||||
}
|
||||
|
||||
result := SelectNames(o)
|
||||
if result != `id, TestHello, head_size` {
|
||||
t.Error("Result was wrong, got:", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhere(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
columns := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"id": 5,
|
||||
"date": time.Now(),
|
||||
}
|
||||
|
||||
result := Where(columns)
|
||||
|
||||
if result != `date=$1 AND id=$2 AND name=$3` {
|
||||
t.Error("Result was wrong, got:", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWhereParams(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
columns := map[string]interface{}{
|
||||
"name": "bob",
|
||||
"id": 5,
|
||||
}
|
||||
|
||||
result := WhereParams(columns)
|
||||
|
||||
if result[0].(int) != 5 {
|
||||
t.Error("Result[0] was wrong, got:", result[0])
|
||||
}
|
||||
|
||||
if result[1].(string) != "bob" {
|
||||
t.Error("Result[1] was wrong, got:", result[1])
|
||||
}
|
||||
}
|
|
@ -7,6 +7,7 @@ type DB interface {
|
|||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
Select(dest interface{}, query string, args ...interface{}) error
|
||||
}
|
||||
|
||||
// M type is for providing where filters to Where helpers.
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
func {{$tableName}}All(db boil.DB) ([]*{{$tableName}}, error) {
|
||||
var {{$varName}} []*{{$tableName}}
|
||||
|
||||
rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}}`)
|
||||
rows, err := db.Query(`SELECT {{selectParamNames .Table .Columns}} FROM {{.Table}}`)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: failed to query: %v", err)
|
||||
}
|
||||
|
|
|
@ -1,13 +0,0 @@
|
|||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}AllBy retrieves all records with the specified column values.
|
||||
func {{titleCase $tableName}}AllBy(db boil.DB, columns map[string]interface{}) ([]*{{titleCase $tableName}}, error) {
|
||||
{{$varName := camelCase $tableName -}}
|
||||
var {{$varName}} []*{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
}
|
||||
|
||||
return {{$varName}}, nil
|
||||
}
|
|
@ -1,15 +0,0 @@
|
|||
{{- $tableName := .Table -}}
|
||||
// {{titleCase $tableName}}FieldsAll retrieves the specified columns for all records.
|
||||
// Pass in a pointer to an object with `db` tags that match the column names you wish to retrieve.
|
||||
// For example: friendName string `db:"friend_name"`
|
||||
func {{titleCase $tableName}}FieldsAll(db boil.DB, results interface{}) error {
|
||||
{{$varName := camelCase $tableName -}}
|
||||
var {{$varName}} []*{{titleCase $tableName}}
|
||||
err := db.Select(&{{$varName}}, `SELECT {{selectParamNames $tableName .Columns}}`)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("models: unable to select from {{$tableName}}: %s", err)
|
||||
}
|
||||
|
||||
return {{$varName}}, nil
|
||||
}
|
Loading…
Reference in a new issue