package strmangle import ( "fmt" "regexp" "strings" "github.com/jinzhu/inflection" "github.com/pobri19/sqlboiler/dbdrivers" ) var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`) // Plural converts singular words to plural words (eg: person to people) func Plural(name string) string { splits := strings.Split(name, "_") splits[len(splits)-1] = inflection.Plural(splits[len(splits)-1]) return strings.Join(splits, "_") } // Singular converts plural words to singular words (eg: people to person) func Singular(name string) string { splits := strings.Split(name, "_") splits[len(splits)-1] = inflection.Singular(splits[len(splits)-1]) return strings.Join(splits, "_") } // TitleCase changes a snake-case variable name // into a go styled object variable name of "ColumnName". // titleCase also fully uppercases "ID" components of names, for example // "column_name_id" to "ColumnNameID". func TitleCase(name string) string { splits := strings.Split(name, "_") for i, split := range splits { if split == "id" { splits[i] = "ID" continue } splits[i] = strings.Title(split) } return strings.Join(splits, "") } // TitleCaseSingular changes a snake-case variable name // to a go styled object variable name of "ColumnName". // titleCaseSingular also converts the last word in the // variable name to a singularized version of itself. func TitleCaseSingular(name string) string { return TitleCase(Singular(name)) } // TitleCasePlural changes a snake-case variable name // to a go styled object variable name of "ColumnName". // titleCasePlural also converts the last word in the // variable name to a pluralized version of itself. func TitleCasePlural(name string) string { return TitleCase(Plural(name)) } // CamelCase takes a variable name in the format of "var_name" and converts // it into a go styled variable name of "varName". // camelCase also fully uppercases "ID" components of names, for example // "var_name_id" to "varNameID". func CamelCase(name string) string { splits := strings.Split(name, "_") for i, split := range splits { if split == "id" && i > 0 { splits[i] = "ID" continue } if i == 0 { continue } splits[i] = strings.Title(split) } return strings.Join(splits, "") } // CamelCaseSingular takes a variable name in the format of "var_name" and converts // it into a go styled variable name of "varName". // CamelCaseSingular also converts the last word in the // variable name to a singularized version of itself. func CamelCaseSingular(name string) string { return CamelCase(Singular(name)) } // CamelCasePlural takes a variable name in the format of "var_name" and converts // it into a go styled variable name of "varName". // CamelCasePlural also converts the last word in the // variable name to a pluralized version of itself. func CamelCasePlural(name string) string { return CamelCase(Plural(name)) } // CamelCaseCommaList generates a list of comma seperated camel cased column names // example: thingName, stuffName, etc func CamelCaseCommaList(pkeyColumns []string) string { var output []string for _, c := range pkeyColumns { output = append(output, CamelCase(c)) } return strings.Join(output, ", ") } // MakeDBName takes a table name in the format of "table_name" and a // column name in the format of "column_name" and returns a name used in the // `db:""` component of an object in the format of "table_name_column_name" func MakeDBName(tableName, colName string) string { return tableName + "_" + colName } // UpdateParamNames takes a []Column and returns a comma seperated // list of parameter names for the update statement template SET clause. // eg: col1=$1,col2=$2,col3=$3 // Note: updateParamNames will exclude the PRIMARY KEY column. func UpdateParamNames(columns []dbdrivers.Column, pkeyColumns []string) string { names := make([]string, 0, len(columns)) counter := 0 for _, c := range columns { if IsPrimaryKey(c.Name, pkeyColumns) { continue } counter++ names = append(names, fmt.Sprintf("%s=$%d", c.Name, counter)) } return strings.Join(names, ",") } // UpdateParamVariables takes a prefix and a []Columns and returns a // comma seperated list of parameter variable names for the update statement. // eg: prefix("o."), column("name_id") -> "o.NameID, ..." // Note: updateParamVariables will exclude the PRIMARY KEY column. func UpdateParamVariables(prefix string, columns []dbdrivers.Column, pkeyColumns []string) string { names := make([]string, 0, len(columns)) for _, c := range columns { if IsPrimaryKey(c.Name, pkeyColumns) { continue } n := prefix + TitleCase(c.Name) names = append(names, n) } return strings.Join(names, ", ") } // IsPrimaryKey checks if the column is found in the primary key columns func IsPrimaryKey(col string, pkeyCols []string) bool { for _, pkey := range pkeyCols { if pkey == col { return true } } return false } // InsertParamNames takes a []Column and returns a comma seperated // list of parameter names for the insert statement template. func InsertParamNames(columns []dbdrivers.Column) string { names := make([]string, 0, len(columns)) for _, c := range columns { names = append(names, c.Name) } return strings.Join(names, ", ") } // InsertParamFlags takes a []Column and returns a comma seperated // list of parameter flags for the insert statement template. func InsertParamFlags(columns []dbdrivers.Column) string { params := make([]string, 0, len(columns)) for i := range columns { params = append(params, fmt.Sprintf("$%d", i+1)) } return strings.Join(params, ", ") } // InsertParamVariables takes a prefix and a []Columns and returns a // comma seperated list of parameter variable names for the insert statement. // For example: prefix("o."), column("name_id") -> "o.NameID, ..." func InsertParamVariables(prefix string, columns []dbdrivers.Column) string { names := make([]string, 0, len(columns)) for _, c := range columns { n := prefix + TitleCase(c.Name) names = append(names, n) } return strings.Join(names, ", ") } // SelectParamNames takes a []Column and returns a comma seperated // list of parameter names with for the select statement template. // It also uses the table name to generate the "AS" part of the statement, for // example: var_name AS table_name_var_name, ... func SelectParamNames(tableName string, columns []dbdrivers.Column) string { selects := make([]string, 0, len(columns)) for _, c := range columns { statement := fmt.Sprintf("%s AS %s", c.Name, MakeDBName(tableName, c.Name)) selects = append(selects, statement) } return strings.Join(selects, ", ") } // ScanParamNames takes a []Column and returns a comma seperated // list of parameter names for use in a db.Scan() call. func ScanParamNames(object string, columns []dbdrivers.Column) string { scans := make([]string, 0, len(columns)) for _, c := range columns { statement := fmt.Sprintf("&%s.%s", object, TitleCase(c.Name)) scans = append(scans, statement) } return strings.Join(scans, ", ") } // HasPrimaryKey returns true if one of the columns passed in is a primary key func HasPrimaryKey(pKey *dbdrivers.PrimaryKey) bool { if pKey == nil || len(pKey.Columns) == 0 { return false } return true } // PrimaryKeyFuncSig generates the function signature parameters. // example: id int64, thingName string func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string { var output []string for _, pk := range pkeyCols { for _, c := range cols { if pk == c.Name { output = append(output, fmt.Sprintf("%s %s", CamelCase(pk), c.Type)) break } } } return strings.Join(output, ", ") } // GenerateParamFlags generates the SQL statement parameter flags // For example, $1,$2,$3 etc. It will start counting at startAt. func GenerateParamFlags(colCount int, startAt int) string { cols := make([]string, 0, colCount) for i := startAt; i < colCount+startAt; i++ { cols = append(cols, fmt.Sprintf("$%d", i)) } return strings.Join(cols, ",") } // WherePrimaryKey returns the where clause using start as the $ flag index // For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3" func WherePrimaryKey(pkeyCols []string, start int) string { var output string cols := make([]string, len(pkeyCols)) copy(cols, pkeyCols) for i, c := range cols { output = fmt.Sprintf("%s%s=$%d", output, c, start) start++ if i < len(cols)-1 { output = fmt.Sprintf("%s AND ", output) } } return output } // PrimaryKeyStrList returns a list of primary key column names in strings // For example: "col1", "col2", "col3" func PrimaryKeyStrList(pkeyCols []string) string { cols := make([]string, len(pkeyCols)) copy(cols, pkeyCols) for i, c := range cols { cols[i] = fmt.Sprintf(`"%s"`, c) } return strings.Join(cols, ", ") } // AutoIncPrimKey returns the auto-increment primary key column name or an empty string func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) string { if pkey == nil { return "" } for _, c := range cols { if rgxAutoIncColumn.MatchString(c.Default) && c.IsNullable == false && c.Type == "int64" { for _, p := range pkey.Columns { if c.Name == p { return p } } } } return "" } // CommaList returns a comma seperated list: "col1", "col2", "col3" func CommaList(cols []string) string { return fmt.Sprintf(`"%s"`, strings.Join(cols, `", "`)) } // ParamsPrimaryKey returns the parameters for the sql statement $ flags // For example, if prefix was "o.", and titleCase was true: "o.ColumnName1, o.ColumnName2" func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) string { names := make([]string, 0, len(columns)) for _, c := range columns { var n string if shouldTitleCase { n = prefix + TitleCase(c) } else { n = prefix + c } names = append(names, n) } return strings.Join(names, ", ") } // PrimaryKeyFlagIndex generates the primary key column flag number for the query params func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int { return len(regularCols) - len(pkeyCols) + 1 } // SupportsResult returns whether the database driver supports the sql.Results // interface, i.e. LastReturnId and RowsAffected func SupportsResultObject(driverName string) bool { switch driverName { case "postgres": return false default: return true } } // FilterColumnsByDefault generates the list of columns that have default values func FilterColumnsByDefault(columns []dbdrivers.Column, defaults bool) string { var cols []string for _, c := range columns { if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) { cols = append(cols, fmt.Sprintf(`"%s"`, c.Name)) } } return strings.Join(cols, `,`) } // FilterColumnsByAutoIncrement generates the list of auto increment columns func FilterColumnsByAutoIncrement(columns []dbdrivers.Column) string { var cols []string for _, c := range columns { if rgxAutoIncColumn.MatchString(c.Default) { cols = append(cols, fmt.Sprintf(`"%s"`, c.Name)) } } return strings.Join(cols, `,`) }