Clean up ridiculous amount of strmangle

This commit is contained in:
Aaron L 2016-06-19 18:45:33 -07:00
parent 8dd3eb72ed
commit 7c2a04ba4d
4 changed files with 133 additions and 441 deletions

View file

@ -1,3 +1,7 @@
// Package strmangle is used exclusively by the templates in sqlboiler.
// There are many helper functions to deal with dbdrivers.* values as well
// as string manipulation. Because it is focused on pipelining inside templates
// you will see some odd parameter ordering.
package strmangle
import (
@ -44,22 +48,6 @@ func TitleCase(name string) string {
return strings.Join(splits, "")
}
// TitleCaseSingular changes a snake-case variable name
// to a go styled object variable name of "ColumnName".
// titleCaseSingular also converts the last word in the
// variable name to a singularized version of itself.
func TitleCaseSingular(name string) string {
return TitleCase(Singular(name))
}
// TitleCasePlural changes a snake-case variable name
// to a go styled object variable name of "ColumnName".
// titleCasePlural also converts the last word in the
// variable name to a pluralized version of itself.
func TitleCasePlural(name string) string {
return TitleCase(Plural(name))
}
// CamelCase takes a variable name in the format of "var_name" and converts
// it into a go styled variable name of "varName".
// camelCase also fully uppercases "ID" components of names, for example
@ -83,91 +71,28 @@ func CamelCase(name string) string {
return strings.Join(splits, "")
}
// CamelCaseSingular takes a variable name in the format of "var_name" and converts
// it into a go styled variable name of "varName".
// CamelCaseSingular also converts the last word in the
// variable name to a singularized version of itself.
func CamelCaseSingular(name string) string {
return CamelCase(Singular(name))
}
// StringMap maps a function over a slice of strings.
func StringMap(modifier func(string) string, strs []string) []string {
ret := make([]string, len(strs))
// CamelCasePlural takes a variable name in the format of "var_name" and converts
// it into a go styled variable name of "varName".
// CamelCasePlural also converts the last word in the
// variable name to a pluralized version of itself.
func CamelCasePlural(name string) string {
return CamelCase(Plural(name))
}
// CamelCaseCommaList generates a list of comma seperated camel cased column names
// example: thingName, o.stuffName, etc
func CamelCaseCommaList(prefix string, cols []string) string {
var output []string
for _, c := range cols {
output = append(output, prefix+CamelCase(c))
for i, str := range strs {
ret[i] = modifier(str)
}
return strings.Join(output, ", ")
}
// TitleCaseCommaList generates a list of comma seperated title cased column names
// example: o.ThingName, o.Stuff, ThingStuff, etc
func TitleCaseCommaList(prefix string, cols []string) string {
var output []string
for _, c := range cols {
output = append(output, prefix+TitleCase(c))
}
return strings.Join(output, ", ")
return ret
}
// MakeDBName takes a table name in the format of "table_name" and a
// column name in the format of "column_name" and returns a name used in the
// `db:""` component of an object in the format of "table_name_column_name"
func MakeDBName(tableName, colName string) string {
return tableName + "_" + colName
return fmt.Sprintf("%s_%s", tableName, colName)
}
// UpdateParamNames takes a []Column and returns a comma seperated
// list of parameter names for the update statement template SET clause.
// eg: col1=$1,col2=$2,col3=$3
// Note: updateParamNames will exclude the PRIMARY KEY column.
func UpdateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string {
names := make([]string, 0, len(columns))
counter := 0
for _, c := range columns {
if IsPrimaryKey(c.Name, pkeyColumns) {
continue
}
counter++
names = append(names, fmt.Sprintf("%s=$%d", c.Name, counter))
}
return strings.Join(names, ",")
}
// UpdateParamVariables takes a prefix and a []Columns and returns a
// comma seperated list of parameter variable names for the update statement.
// eg: prefix("o."), column("name_id") -> "o.NameID, ..."
// Note: updateParamVariables will exclude the PRIMARY KEY column.
func UpdateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string {
names := make([]string, 0, len(columns))
for _, c := range columns {
if IsPrimaryKey(c.Name, pkeyColumns) {
continue
}
names = append(names, fmt.Sprintf("%s%s", prefix, TitleCase(c.Name)))
}
return strings.Join(names, ", ")
}
// IsPrimaryKey checks if the column is found in the primary key columns
func IsPrimaryKey(col string, pkeyCols []string) bool {
for _, pkey := range pkeyCols {
if pkey == col {
// HasElement checks to see if the string is found in the string slice
func HasElement(str string, slice []string) bool {
for _, s := range slice {
if str == s {
return true
}
}
@ -175,87 +100,33 @@ func IsPrimaryKey(col string, pkeyCols []string) bool {
return false
}
// InsertParamNames takes a []Column and returns a comma seperated
// list of parameter names for the insert statement template.
func InsertParamNames(columns []dbdrivers.Column) string {
names := make([]string, len(columns))
for i, c := range columns {
names[i] = c.Name
}
return strings.Join(names, ", ")
}
// PrefixStringSlice with the given str.
func PrefixStringSlice(str string, strs []string) []string {
ret := make([]string, len(strs))
// InsertParamFlags takes a []Column and returns a comma seperated
// list of parameter flags for the insert statement template.
func InsertParamFlags(columns []dbdrivers.Column) string {
params := make([]string, len(columns))
for i := range columns {
params[i] = fmt.Sprintf("$%d", i+1)
}
return strings.Join(params, ", ")
}
// InsertParamVariables takes a prefix and a []Columns and returns a
// comma seperated list of parameter variable names for the insert statement.
// For example: prefix("o."), column("name_id") -> "o.NameID, ..."
func InsertParamVariables(prefix string, columns []dbdrivers.Column) string {
names := make([]string, len(columns))
for i, c := range columns {
names[i] = prefix + TitleCase(c.Name)
for i, s := range strs {
ret[i] = fmt.Sprintf("%s%s", str, s)
}
return strings.Join(names, ", ")
}
// SelectParamNames takes a []Column and returns a comma seperated
// list of parameter names with for the select statement template.
// It also uses the table name to generate the "AS" part of the statement, for
// example: var_name AS table_name_var_name, ...
func SelectParamNames(tableName string, columns []dbdrivers.Column) string {
selects := make([]string, len(columns))
for i, c := range columns {
selects[i] = fmt.Sprintf("%s AS %s", c.Name, MakeDBName(tableName, c.Name))
}
return strings.Join(selects, ", ")
}
// ScanParamNames takes a []Column and returns a comma seperated
// list of parameter names for use in a db.Scan() call.
func ScanParamNames(object string, columns []dbdrivers.Column) string {
scans := make([]string, len(columns))
for i, c := range columns {
scans[i] = fmt.Sprintf("&%s.%s", object, TitleCase(c.Name))
}
return strings.Join(scans, ", ")
}
// HasPrimaryKey returns true if one of the columns passed in is a primary key
func HasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool {
if pKey == nil || len(pKey.Columns) == 0 {
return false
}
return true
return ret
}
// PrimaryKeyFuncSig generates the function signature parameters.
// example: id int64, thingName string
func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string {
var output []string
ret := make([]string, len(pkeyCols))
for _, pk := range pkeyCols {
for i, pk := range pkeyCols {
for _, c := range cols {
if pk == c.Name {
output = append(output, fmt.Sprintf("%s %s", CamelCase(pk), c.Type))
break
if pk != c.Name {
continue
}
ret[i] = fmt.Sprintf("%s %s", CamelCase(pk), c.Type)
}
}
return strings.Join(output, ", ")
return strings.Join(ret, ", ")
}
// GenerateParamFlags generates the SQL statement parameter flags
@ -273,26 +144,16 @@ func GenerateParamFlags(colCount int, startAt int) string {
// WherePrimaryKey returns the where clause using start as the $ flag index
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
func WherePrimaryKey(pkeyCols []string, start int) string {
var output string
// 0 is not a valid start number
if start == 0 {
start = 1
panic("0 is not a valid start number for wherePrimaryKey")
}
cols := make([]string, len(pkeyCols))
copy(cols, pkeyCols)
for i, c := range cols {
output = fmt.Sprintf("%s%s=$%d", output, c, start)
start++
if i < len(cols)-1 {
output = fmt.Sprintf("%s AND ", output)
}
for i, c := range pkeyCols {
cols[i] = fmt.Sprintf("%s=$%d", c, start+i)
}
return output
return strings.Join(cols, " AND ")
}
// AutoIncPrimaryKey returns the auto-increment primary key column name or an
@ -320,8 +181,8 @@ func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) stri
return ""
}
// ColumnsToStrings changes the columns into a list of column names
func ColumnsToStrings(cols []dbdrivers.Column) []string {
// ColumnNames of the columns.
func ColumnNames(cols []dbdrivers.Column) []string {
names := make([]string, len(cols))
for i, c := range cols {
names[i] = c.Name
@ -330,37 +191,9 @@ func ColumnsToStrings(cols []dbdrivers.Column) []string {
return names
}
// CommaList returns a comma seperated list: "col1", "col2", "col3"
func CommaList(cols []string) string {
return fmt.Sprintf(`"%s"`, strings.Join(cols, `", "`))
}
// ParamsPrimaryKey returns the parameters for the sql statement $ flags
// For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2"
func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string {
names := make([]string, 0, len(columns))
for _, c := range columns {
var n string
if shouldTitleCase {
n = prefix + TitleCase(c)
} else {
n = prefix + c
}
names = append(names, n)
}
return strings.Join(names, ", ")
}
// PrimaryKeyFlagIndex generates the primary key column flag number for the query params
func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
return len(regularCols) - len(pkeyCols) + 1
}
// SupportsResultObject returns whether the database driver supports the
// sql.Results interface, i.e. LastReturnId and RowsAffected
func SupportsResultObject(driverName string) bool {
// DriverUsesLastInsertID returns whether the database driver supports the
// sql.Result interface.
func DriverUsesLastInsertID(driverName string) bool {
switch driverName {
case "postgres":
return false
@ -395,16 +228,6 @@ func FilterColumnsByAutoIncrement(columns []dbdrivers.Column) string {
return strings.Join(cols, `,`)
}
// AddID to the end of the string
func AddID(str string) string {
return str + "_id"
}
// RemoveID from the end of the string
func RemoveID(str string) string {
return strings.TrimSuffix(str, "_id")
}
// Substring returns a substring of str starting at index start and going
// to end-1.
func Substring(start, end int, str string) string {

View file

@ -1,6 +1,7 @@
package strmangle
import (
"strings"
"testing"
"github.com/nullbio/sqlboiler/dbdrivers"
@ -11,82 +12,6 @@ var testColumns = []dbdrivers.Column{
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
}
func TestCommaList(t *testing.T) {
t.Parallel()
cols := []string{
"test1",
}
x := CommaList(cols)
if x != `"test1"` {
t.Errorf(`Expected "test1" - got %s`, x)
}
cols = append(cols, "test2")
x = CommaList(cols)
if x != `"test1", "test2"` {
t.Errorf(`Expected "test1", "test2" - got %s`, x)
}
cols = append(cols, "test3")
x = CommaList(cols)
if x != `"test1", "test2", "test3"` {
t.Errorf(`Expected "test1", "test2", "test3" - got %s`, x)
}
}
func TestTitleCaseCommaList(t *testing.T) {
t.Parallel()
cols := []string{
"test_id",
"test_thing",
"test_stuff_thing",
"test",
}
x := TitleCaseCommaList("", cols)
expected := `TestID, TestThing, TestStuffThing, Test`
if x != expected {
t.Errorf("Expected %s, got %s", expected, x)
}
x = TitleCaseCommaList("o.", cols)
expected = `o.TestID, o.TestThing, o.TestStuffThing, o.Test`
if x != expected {
t.Errorf("Expected %s, got %s", expected, x)
}
}
func TestCamelCaseCommaList(t *testing.T) {
t.Parallel()
cols := []string{
"test_id",
"test_thing",
"test_stuff_thing",
"test",
}
x := CamelCaseCommaList("", cols)
expected := `testID, testThing, testStuffThing, test`
if x != expected {
t.Errorf("Expected %s, got %s", expected, x)
}
x = CamelCaseCommaList("o.", cols)
expected = `o.testID, o.testThing, o.testStuffThing, o.test`
if x != expected {
t.Errorf("Expected %s, got %s", expected, x)
}
}
func TestAutoIncPrimaryKey(t *testing.T) {
t.Parallel()
@ -95,6 +20,11 @@ func TestAutoIncPrimaryKey(t *testing.T) {
Pkey *dbdrivers.PrimaryKey
Columns []dbdrivers.Column
}{
"nillcase": {
Expect: "",
Pkey: nil,
Columns: nil,
},
"easycase": {
Expect: "one",
Pkey: &dbdrivers.PrimaryKey{
@ -180,6 +110,32 @@ func TestAutoIncPrimaryKey(t *testing.T) {
}
}
func TestColumnNames(t *testing.T) {
t.Parallel()
cols := []dbdrivers.Column{
dbdrivers.Column{Name: "one"},
dbdrivers.Column{Name: "two"},
dbdrivers.Column{Name: "three"},
}
out := strings.Join(ColumnNames(cols), " ")
if out != "one two three" {
t.Error("output was wrong:", out)
}
}
func TestDriverUsesResults(t *testing.T) {
t.Parallel()
if DriverUsesLastInsertID("postgres") {
t.Error("postgres does not support LastInsertId")
}
if !DriverUsesLastInsertID("mysql") {
t.Error("postgres does support LastInsertId")
}
}
func TestGenerateParamFlags(t *testing.T) {
t.Parallel()
@ -266,6 +222,15 @@ func TestCamelCase(t *testing.T) {
}
}
func TestStringMap(t *testing.T) {
t.Parallel()
mapped := StringMap(strings.ToLower, []string{"HELLO", "WORLD"})
if got := strings.Join(mapped, " "); got != "hello world" {
t.Errorf("mapped was wrong: %q", got)
}
}
func TestMakeDBName(t *testing.T) {
t.Parallel()
@ -274,153 +239,53 @@ func TestMakeDBName(t *testing.T) {
}
}
func TestUpdateParamNames(t *testing.T) {
func TestHasElement(t *testing.T) {
t.Parallel()
var testCols = []dbdrivers.Column{
{Name: "id", Type: "int", IsNullable: false},
{Name: "friend_column", Type: "int", IsNullable: false},
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
elements := []string{"one", "two"}
if got := HasElement("one", elements); !got {
t.Error("should have found element key")
}
out := UpdateParamNames(testCols, []string{"id"})
if out != "friend_column=$1,enemy_column_thing=$2" {
t.Error("Wrong output:", out)
if got := HasElement("three", elements); got {
t.Error("should not have found element key")
}
}
func TestUpdateParamVariables(t *testing.T) {
func TestPrefixStringSlice(t *testing.T) {
t.Parallel()
var testCols = []dbdrivers.Column{
{Name: "id", Type: "int", IsNullable: false},
{Name: "friend_column", Type: "int", IsNullable: false},
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
}
out := UpdateParamVariables("o.", testCols, []string{"id"})
if out != "o.FriendColumn, o.EnemyColumnThing" {
t.Error("Wrong output:", out)
slice := PrefixStringSlice("o.", []string{"one", "two"})
if got := strings.Join(slice, " "); got != "o.one o.two" {
t.Error("wrong output:", got)
}
}
func TestInsertParamNames(t *testing.T) {
func TestPrimaryKeyFuncSig(t *testing.T) {
t.Parallel()
out := InsertParamNames(testColumns)
if out != "friend_column, enemy_column_thing" {
t.Error("Wrong output:", out)
}
}
func TestInsertParamFlags(t *testing.T) {
t.Parallel()
out := InsertParamFlags(testColumns)
if out != "$1, $2" {
t.Error("Wrong output:", out)
}
}
func TestInsertParamVariables(t *testing.T) {
t.Parallel()
out := InsertParamVariables("o.", testColumns)
if out != "o.FriendColumn, o.EnemyColumnThing" {
t.Error("Wrong output:", out)
}
}
func TestSelectParamFlags(t *testing.T) {
t.Parallel()
out := SelectParamNames("table", testColumns)
if out != "friend_column AS table_friend_column, enemy_column_thing AS table_enemy_column_thing" {
t.Error("Wrong output:", out)
}
}
func TestScanParams(t *testing.T) {
t.Parallel()
out := ScanParamNames("object", testColumns)
if out != "&object.FriendColumn, &object.EnemyColumnThing" {
t.Error("Wrong output:", out)
}
}
func TestHasPrimaryKey(t *testing.T) {
t.Parallel()
var pkey *dbdrivers.PrimaryKey
if HasPrimaryKey(pkey) {
t.Errorf("1) Expected false, got true")
}
pkey = &dbdrivers.PrimaryKey{}
if HasPrimaryKey(pkey) {
t.Errorf("2) Expected false, got true")
}
pkey.Columns = append(pkey.Columns, "test")
if !HasPrimaryKey(pkey) {
t.Errorf("3) Expected true, got false")
}
}
func TestParamsPrimaryKey(t *testing.T) {
t.Parallel()
tests := []struct {
Pkey dbdrivers.PrimaryKey
Prefix string
Should string
}{
cols := []dbdrivers.Column{
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
Prefix: "o.", Should: "o.ColOne",
Name: "one",
Type: "int64",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
Prefix: "o.", Should: "o.ColOne, o.ColTwo",
Name: "two",
Type: "string",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
Prefix: "o.", Should: "o.ColOne, o.ColTwo, o.ColThree",
Name: "three",
Type: "string",
},
}
for i, test := range tests {
r := ParamsPrimaryKey(test.Prefix, test.Pkey.Columns, true)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
sig := PrimaryKeyFuncSig(cols, []string{"one"})
if sig != "one int64" {
t.Error("wrong signature:", sig)
}
tests2 := []struct {
Pkey dbdrivers.PrimaryKey
Prefix string
Should string
}{
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one"}},
Prefix: "o.", Should: "o.col_one",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two"}},
Prefix: "o.", Should: "o.col_one, o.col_two",
},
{
Pkey: dbdrivers.PrimaryKey{Columns: []string{"col_one", "col_two", "col_three"}},
Prefix: "o.", Should: "o.col_one, o.col_two, o.col_three",
},
}
for i, test := range tests2 {
r := ParamsPrimaryKey(test.Prefix, test.Pkey.Columns, false)
if r != test.Should {
t.Errorf("(%d) want: %s, got: %s\nTest: %#v", i, test.Should, r, test)
}
sig = PrimaryKeyFuncSig(cols, []string{"one", "three"})
if sig != "one int64, three string" {
t.Error("wrong signature:", sig)
}
}
@ -445,6 +310,18 @@ func TestWherePrimaryKey(t *testing.T) {
}
}
func TestWherePrimaryKeyPanic(t *testing.T) {
t.Parallel()
defer func() {
if recover() == nil {
t.Error("did not panic")
}
}()
WherePrimaryKey(nil, 0)
}
func TestFilterColumnsByDefault(t *testing.T) {
t.Parallel()
@ -515,3 +392,22 @@ func TestFilterColumnsByAutoIncrement(t *testing.T) {
t.Errorf("Invalid result: %s", res)
}
}
func TestSubstring(t *testing.T) {
t.Parallel()
str := "hello"
if got := Substring(0, 5, str); got != "hello" {
t.Errorf("substring was wrong: %q", got)
}
if got := Substring(1, 4, str); got != "ell" {
t.Errorf("substring was wrong: %q", got)
}
if got := Substring(2, 3, str); got != "l" {
t.Errorf("substring was wrong: %q", got)
}
if got := Substring(5, 5, str); got != "" {
t.Errorf("substring was wrong: %q", got)
}
}

View file

@ -1,26 +0,0 @@
package strmangle
import (
"fmt"
"strings"
"github.com/nullbio/sqlboiler/dbdrivers"
)
// RandDBStruct does nothing yet
// TODO(nullbio): What is this?
func RandDBStruct(varName string, table dbdrivers.Table) string {
return ""
}
// RandDBStructSlice randomizes a struct?
// TODO(nullbio): What is this?
func RandDBStructSlice(varName string, num int, table dbdrivers.Table) string {
var structs []string
for i := 0; i < num; i++ {
structs = append(structs, RandDBStruct(varName, table))
}
innerStructs := strings.Join(structs, ",")
return fmt.Sprintf("%s := %s{%s}", varName, TitleCasePlural(table.Name), innerStructs)
}

View file

@ -1 +0,0 @@
package strmangle