Finished insert template
* Removed functions in helpers not being used
This commit is contained in:
parent
a957bc3836
commit
f059bdebf4
10 changed files with 576 additions and 98 deletions
16
boil/bind.go
16
boil/bind.go
|
@ -49,6 +49,22 @@ func checkType(obj interface{}) (reflect.Type, bool, error) {
|
||||||
return typ, isSlice, nil
|
return typ, isSlice, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetStructValues returns the values (as interface) of the matching columns in obj
|
||||||
|
func GetStructValues(obj interface{}, columns ...string) []interface{} {
|
||||||
|
ret := make([]interface{}, len(columns))
|
||||||
|
val := reflect.ValueOf(obj)
|
||||||
|
if val.Kind() == reflect.Ptr {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, c := range columns {
|
||||||
|
field := val.FieldByName(strmangle.TitleCase(c))
|
||||||
|
ret[i] = field.Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return ret
|
||||||
|
}
|
||||||
|
|
||||||
// GetStructPointers returns a slice of pointers to the matching columns in obj
|
// GetStructPointers returns a slice of pointers to the matching columns in obj
|
||||||
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
|
func GetStructPointers(obj interface{}, columns ...string) []interface{} {
|
||||||
val := reflect.ValueOf(obj).Elem()
|
val := reflect.ValueOf(obj).Elem()
|
||||||
|
|
|
@ -1,6 +1,54 @@
|
||||||
package boil
|
package boil
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/guregu/null"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetStructValues(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
timeThing := time.Now()
|
||||||
|
o := struct {
|
||||||
|
TitleThing string
|
||||||
|
Name string
|
||||||
|
ID int
|
||||||
|
Stuff int
|
||||||
|
Things int
|
||||||
|
Time time.Time
|
||||||
|
NullBool null.Bool
|
||||||
|
}{
|
||||||
|
TitleThing: "patrick",
|
||||||
|
Stuff: 10,
|
||||||
|
Things: 0,
|
||||||
|
Time: timeThing,
|
||||||
|
NullBool: null.NewBool(true, false),
|
||||||
|
}
|
||||||
|
|
||||||
|
vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool")
|
||||||
|
if vals[0].(string) != "patrick" {
|
||||||
|
t.Errorf("Want test, got %s", vals[0])
|
||||||
|
}
|
||||||
|
if vals[1].(string) != "" {
|
||||||
|
t.Errorf("Want empty string, got %s", vals[1])
|
||||||
|
}
|
||||||
|
if vals[2].(int) != 0 {
|
||||||
|
t.Errorf("Want 0, got %d", vals[2])
|
||||||
|
}
|
||||||
|
if vals[3].(int) != 10 {
|
||||||
|
t.Errorf("Want 10, got %d", vals[3])
|
||||||
|
}
|
||||||
|
if vals[4].(int) != 0 {
|
||||||
|
t.Errorf("Want 0, got %d", vals[4])
|
||||||
|
}
|
||||||
|
if !vals[5].(time.Time).Equal(timeThing) {
|
||||||
|
t.Errorf("Want %s, got %s", o.Time, vals[5])
|
||||||
|
}
|
||||||
|
if !vals[6].(null.Bool).IsZero() {
|
||||||
|
t.Errorf("Want %v, got %v", o.NullBool, vals[6])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestGetStructPointers(t *testing.T) {
|
func TestGetStructPointers(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
|
@ -7,8 +7,79 @@ import (
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
|
|
||||||
|
"github.com/pobri19/sqlboiler/strmangle"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// SetComplement subtracts the elements in b from a
|
||||||
|
func SetComplement(a []string, b []string) []string {
|
||||||
|
c := make([]string, 0, len(a))
|
||||||
|
|
||||||
|
for _, aVal := range a {
|
||||||
|
found := false
|
||||||
|
for _, bVal := range b {
|
||||||
|
if aVal == bVal {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
c = append(c, aVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetIntersect returns the elements that are common in a and b
|
||||||
|
func SetIntersect(a []string, b []string) []string {
|
||||||
|
c := make([]string, 0, len(a))
|
||||||
|
|
||||||
|
for _, aVal := range a {
|
||||||
|
found := false
|
||||||
|
for _, bVal := range b {
|
||||||
|
if aVal == bVal {
|
||||||
|
found = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if found {
|
||||||
|
c = append(c, aVal)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// NonZeroDefaultSet returns the fields included in the
|
||||||
|
// defaults slice that are non zero values
|
||||||
|
func NonZeroDefaultSet(defaults []string, obj interface{}) []string {
|
||||||
|
c := make([]string, 0, len(defaults))
|
||||||
|
|
||||||
|
val := reflect.ValueOf(obj)
|
||||||
|
|
||||||
|
for _, d := range defaults {
|
||||||
|
fieldName := strmangle.TitleCase(d)
|
||||||
|
field := val.FieldByName(fieldName)
|
||||||
|
if !field.IsValid() {
|
||||||
|
panic(fmt.Sprintf("Could not find field name %s in type %T", fieldName, obj))
|
||||||
|
}
|
||||||
|
|
||||||
|
zero := reflect.Zero(field.Type())
|
||||||
|
if !reflect.DeepEqual(zero.Interface(), field.Interface()) {
|
||||||
|
c = append(c, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
// GenerateParamFlags generates the SQL statement parameter flags
|
||||||
|
// For example, $1,$2,$3 etc. It will start counting at startAt.
|
||||||
|
func GenerateParamFlags(colCount int, startAt int) string {
|
||||||
|
return strmangle.GenerateParamFlags(colCount, startAt)
|
||||||
|
}
|
||||||
|
|
||||||
// WherePrimaryKeyIn generates a "in" string for where queries
|
// WherePrimaryKeyIn generates a "in" string for where queries
|
||||||
// For example: (col1, col2) IN (($1, $2), ($3, $4))
|
// For example: (col1, col2) IN (($1, $2), ($3, $4))
|
||||||
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
func WherePrimaryKeyIn(numRows int, keyNames ...string) string {
|
||||||
|
@ -81,15 +152,13 @@ func SelectNames(results interface{}) string {
|
||||||
|
|
||||||
// WhereClause returns the where clause for an sql statement
|
// WhereClause returns the where clause for an sql statement
|
||||||
// eg: col1=$1 AND col2=$2 AND col3=$3
|
// eg: col1=$1 AND col2=$2 AND col3=$3
|
||||||
func WhereClause(columns map[string]interface{}) string {
|
func WhereClause(columns []string) string {
|
||||||
names := make([]string, 0, len(columns))
|
names := make([]string, 0, len(columns))
|
||||||
|
|
||||||
for c := range columns {
|
for _, c := range columns {
|
||||||
names = append(names, c)
|
names = append(names, c)
|
||||||
}
|
}
|
||||||
|
|
||||||
sort.Strings(names)
|
|
||||||
|
|
||||||
for i, c := range names {
|
for i, c := range names {
|
||||||
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
names[i] = fmt.Sprintf("%s=$%d", c, i+1)
|
||||||
}
|
}
|
||||||
|
@ -115,24 +184,6 @@ func Update(columns map[string]interface{}) string {
|
||||||
return strings.Join(names, ",")
|
return strings.Join(names, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
// WhereParams returns a list of sql parameter values for the query
|
|
||||||
func WhereParams(columns map[string]interface{}) []interface{} {
|
|
||||||
names := make([]string, 0, len(columns))
|
|
||||||
results := make([]interface{}, 0, len(columns))
|
|
||||||
|
|
||||||
for c := range columns {
|
|
||||||
names = append(names, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Strings(names)
|
|
||||||
|
|
||||||
for _, c := range names {
|
|
||||||
results = append(results, columns[c])
|
|
||||||
}
|
|
||||||
|
|
||||||
return results
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetParamNames takes a slice of columns and returns a comma seperated
|
// SetParamNames takes a slice of columns and returns a comma seperated
|
||||||
// list of parameter names for a template statement SET clause.
|
// list of parameter names for a template statement SET clause.
|
||||||
// eg: col1=$1,col2=$2,col3=$3
|
// eg: col1=$1,col2=$2,col3=$3
|
||||||
|
|
|
@ -1,8 +1,11 @@
|
||||||
package boil
|
package boil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/guregu/null"
|
||||||
)
|
)
|
||||||
|
|
||||||
type testObj struct {
|
type testObj struct {
|
||||||
|
@ -11,6 +14,134 @@ type testObj struct {
|
||||||
HeadSize int
|
HeadSize int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestSetComplement(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
A []string
|
||||||
|
B []string
|
||||||
|
C []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2", "thing3"},
|
||||||
|
[]string{"thing2", "otherthing", "stuff"},
|
||||||
|
[]string{"thing1", "thing3"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
c := SetComplement(test.A, test.B)
|
||||||
|
if !reflect.DeepEqual(test.C, c) {
|
||||||
|
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetIntersect(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
A []string
|
||||||
|
B []string
|
||||||
|
C []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2", "thing3"},
|
||||||
|
[]string{"thing2", "otherthing", "stuff"},
|
||||||
|
[]string{"thing2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
[]string{"thing1", "thing2"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
c := SetIntersect(test.A, test.B)
|
||||||
|
if !reflect.DeepEqual(test.C, c) {
|
||||||
|
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.C, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNonZeroDefaultSet(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
type Anything struct {
|
||||||
|
ID int
|
||||||
|
Name string
|
||||||
|
CreatedAt *time.Time
|
||||||
|
UpdatedAt null.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
Defaults []string
|
||||||
|
Obj interface{}
|
||||||
|
Ret []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
[]string{"id"},
|
||||||
|
Anything{Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"id"},
|
||||||
|
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||||
|
[]string{"id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{},
|
||||||
|
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||||
|
[]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"id", "created_at", "updated_at"},
|
||||||
|
Anything{ID: 5, Name: "hi", CreatedAt: nil, UpdatedAt: null.Time{Valid: false}},
|
||||||
|
[]string{"id"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
[]string{"id", "created_at", "updated_at"},
|
||||||
|
Anything{ID: 5, Name: "hi", CreatedAt: &now, UpdatedAt: null.Time{Valid: true, Time: time.Now()}},
|
||||||
|
[]string{"id", "created_at", "updated_at"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
z := NonZeroDefaultSet(test.Defaults, test.Obj)
|
||||||
|
if !reflect.DeepEqual(test.Ret, z) {
|
||||||
|
t.Errorf("[%d] mismatch:\nWant: %#v\nGot: %#v", i, test.Ret, z)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestWherePrimaryKeyIn(t *testing.T) {
|
func TestWherePrimaryKeyIn(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -105,34 +236,15 @@ func TestSelectNames(t *testing.T) {
|
||||||
func TestWhereClause(t *testing.T) {
|
func TestWhereClause(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
columns := map[string]interface{}{
|
columns := []string{
|
||||||
"name": "bob",
|
"id",
|
||||||
"id": 5,
|
"name",
|
||||||
"date": time.Now(),
|
"date",
|
||||||
}
|
}
|
||||||
|
|
||||||
result := WhereClause(columns)
|
result := WhereClause(columns)
|
||||||
|
|
||||||
if result != `date=$1 AND id=$2 AND name=$3` {
|
if result != `id=$1 AND name=$2 AND date=$3` {
|
||||||
t.Error("Result was wrong, got:", result)
|
t.Error("Result was wrong, got:", result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestWhereParams(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
columns := map[string]interface{}{
|
|
||||||
"name": "bob",
|
|
||||||
"id": 5,
|
|
||||||
}
|
|
||||||
|
|
||||||
result := WhereParams(columns)
|
|
||||||
|
|
||||||
if result[0].(int) != 5 {
|
|
||||||
t.Error("Result[0] was wrong, got:", result[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
if result[1].(string) != "bob" {
|
|
||||||
t.Error("Result[1] was wrong, got:", result[1])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -37,6 +37,7 @@ var sqlBoilerImports = imports{
|
||||||
standard: importList{
|
standard: importList{
|
||||||
`"errors"`,
|
`"errors"`,
|
||||||
`"fmt"`,
|
`"fmt"`,
|
||||||
|
`"strings"`,
|
||||||
},
|
},
|
||||||
thirdparty: importList{
|
thirdparty: importList{
|
||||||
`"github.com/pobri19/sqlboiler/boil"`,
|
`"github.com/pobri19/sqlboiler/boil"`,
|
||||||
|
@ -108,6 +109,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{
|
||||||
"updateParamNames": strmangle.UpdateParamNames,
|
"updateParamNames": strmangle.UpdateParamNames,
|
||||||
"updateParamVariables": strmangle.UpdateParamVariables,
|
"updateParamVariables": strmangle.UpdateParamVariables,
|
||||||
"primaryKeyStrList": strmangle.PrimaryKeyStrList,
|
"primaryKeyStrList": strmangle.PrimaryKeyStrList,
|
||||||
|
"supportsResultObject": strmangle.SupportsResultObject,
|
||||||
|
"filterColumnsByDefault": strmangle.FilterColumnsByDefault,
|
||||||
|
"autoIncPrimaryKey": strmangle.AutoIncPrimaryKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfigFile loads the toml config file into the cfg object
|
// LoadConfigFile loads the toml config file into the cfg object
|
||||||
|
|
|
@ -229,6 +229,26 @@ func initTables(tableName string, cmdData *CmdData) error {
|
||||||
return errors.New("No tables found in database, migrate some tables first")
|
return errors.New("No tables found in database, migrate some tables first")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := checkPKeys(cmdData.Tables); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// checkPKeys ensures every table has a primary key column
|
||||||
|
func checkPKeys(tables []dbdrivers.Table) error {
|
||||||
|
var missingPkey []string
|
||||||
|
for _, t := range tables {
|
||||||
|
if t.PKey == nil {
|
||||||
|
missingPkey = append(missingPkey, t.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(missingPkey) != 0 {
|
||||||
|
return fmt.Errorf("Cannot continue until the follow tables have PRIMARY KEY columns: %s", strings.Join(missingPkey, ", "))
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -40,8 +40,23 @@ func init() {
|
||||||
{
|
{
|
||||||
Name: "spiderman",
|
Name: "spiderman",
|
||||||
Columns: []dbdrivers.Column{
|
Columns: []dbdrivers.Column{
|
||||||
|
{Name: "id", Type: "int64", IsNullable: false},
|
||||||
|
},
|
||||||
|
PKey: &dbdrivers.PrimaryKey{
|
||||||
|
Name: "pkey_id",
|
||||||
|
Columns: []string{"id"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "spiderman_table_two",
|
||||||
|
Columns: []dbdrivers.Column{
|
||||||
|
{Name: "id", Type: "int64", IsNullable: false},
|
||||||
{Name: "patrick", Type: "string", IsNullable: false},
|
{Name: "patrick", Type: "string", IsNullable: false},
|
||||||
},
|
},
|
||||||
|
PKey: &dbdrivers.PrimaryKey{
|
||||||
|
Name: "pkey_id",
|
||||||
|
Columns: []string{"id"},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
PkgName: "patrick",
|
PkgName: "patrick",
|
||||||
|
@ -69,6 +84,10 @@ func TestTemplates(t *testing.T) {
|
||||||
t.SkipNow()
|
t.SkipNow()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := checkPKeys(cmdData.Tables); err != nil {
|
||||||
|
t.Fatalf("%s", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Initialize the templates
|
// Initialize the templates
|
||||||
var err error
|
var err error
|
||||||
cmdData.Templates, err = loadTemplates("templates")
|
cmdData.Templates, err = loadTemplates("templates")
|
||||||
|
|
|
@ -1,47 +1,86 @@
|
||||||
|
{{- if hasPrimaryKey .Table.PKey -}}
|
||||||
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
{{- $tableNameSingular := titleCaseSingular .Table.Name -}}
|
||||||
|
{{- $varNameSingular := camelCaseSingular .Table.Name -}}
|
||||||
// {{$tableNameSingular}}Insert inserts a single record.
|
// {{$tableNameSingular}}Insert inserts a single record.
|
||||||
func (o *{{$tableNameSingular}}) Insert(whitelist ... string) error {
|
func (o *{{$tableNameSingular}}) Insert(whitelist ... string) error {
|
||||||
if o == nil {
|
return o.InsertX(boil.GetDB(), whitelist...)
|
||||||
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := o.doBeforeCreateHooks(); err != nil {
|
var {{$varNameSingular}}DefaultInsertWhitelist = []string{{"{"}}{{filterColumnsByDefault .Table.Columns false}}{{"}"}}
|
||||||
return 0, err
|
var {{$varNameSingular}}ColumnsWithDefault = []string{{"{"}}{{filterColumnsByDefault .Table.Columns true}}{{"}"}}
|
||||||
}
|
var {{$varNameSingular}}AutoIncPrimaryKey = "{{autoIncPrimaryKey .Table.Columns .Table.PKey}}"
|
||||||
|
|
||||||
var rowID int
|
|
||||||
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := o.doAfterCreateHooks(); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return rowID, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... string) error {
|
func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... string) error {
|
||||||
if o == nil {
|
if o == nil {
|
||||||
return 0, errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
|
return errors.New("{{.PkgName}}: no {{.Table.Name}} provided for insertion")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(whitelist) == 0 {
|
||||||
|
whitelist = {{$varNameSingular}}DefaultInsertWhitelist
|
||||||
|
}
|
||||||
|
|
||||||
|
nzDefaultSet := boil.NonZeroDefaultSet({{$varNameSingular}}ColumnsWithDefault, o)
|
||||||
|
if len(nzDefaultSet) != 0 {
|
||||||
|
whitelist = append(nzDefaultSet, whitelist...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Only return the columns with default values that are not in the insert whitelist
|
||||||
|
returnColumns := boil.SetComplement({{$varNameSingular}}ColumnsWithDefault, whitelist)
|
||||||
|
|
||||||
|
var err error
|
||||||
if err := o.doBeforeCreateHooks(); err != nil {
|
if err := o.doBeforeCreateHooks(); err != nil {
|
||||||
return 0, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
var rowID int
|
ins := fmt.Sprintf(`INSERT INTO {{.Table.Name}} (%s) VALUES (%s)`, strings.Join(whitelist, ","), boil.GenerateParamFlags(len(whitelist), 1))
|
||||||
err := boil.GetDB().QueryRow(`INSERT INTO {{.Table.Name}} ({{insertParamNames .Table.Columns}}) VALUES({{insertParamFlags .Table.Columns}}) RETURNING id`, {{insertParamVariables "o." .Table.Columns}}).Scan(&rowID)
|
|
||||||
|
{{if supportsResultObject .DriverName}}
|
||||||
|
if len(returnColumns) != 0 {
|
||||||
|
result, err := exec.Exec(ins, boil.GetStructValues(o, whitelist...))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
lastId, err := result.lastInsertId()
|
||||||
|
if err != nil || lastId == 0 {
|
||||||
|
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, ","), boil.WhereClause(whitelist))
|
||||||
|
rows, err := exec.Query(sel, boil.GetStructValues(o, whitelist...))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||||
|
}
|
||||||
|
defer rows.Close()
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
ptrs := boil.GetStructPointers(o, returnColumns...)
|
||||||
|
for rows.Next() {
|
||||||
|
if err := rows.Scan(ptrs[i]); err != nil {
|
||||||
|
return fmt.Errorf("{{.PkgName}}: unable to get result of insert, scan failed for column %s index %d: %s\n\n%#v", returnColumns[i], i, err, ptrs)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
} else if {{$varNameSingular}}AutoIncPrimKey != "" {
|
||||||
|
sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s=$1`, strings.Join(returnColumns, ","), {{$varNameSingular}}AutoIncPrimaryKey, lastId)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
_, err = exec.Exec(ins, boil.GetStructValues(o, whitelist...))
|
||||||
|
}
|
||||||
|
{{else}}
|
||||||
|
if len(returnColumns) != 0 {
|
||||||
|
ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
|
||||||
|
err = exec.QueryRow(ins, boil.GetStructValues(o, whitelist...)).Scan(boil.GetStructPointers(o, returnColumns...))
|
||||||
|
} else {
|
||||||
|
_, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}})
|
||||||
|
}
|
||||||
|
{{end}}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, fmt.Errorf("{{.PkgName}}: unable to insert {{.Table.Name}}: %s", err)
|
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := o.doAfterCreateHooks(); err != nil {
|
if err := o.doAfterCreateHooks(); err != nil {
|
||||||
return 0, err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return rowID, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
{{- end -}}
|
||||||
|
|
|
@ -2,12 +2,15 @@ package strmangle
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/jinzhu/inflection"
|
"github.com/jinzhu/inflection"
|
||||||
"github.com/pobri19/sqlboiler/dbdrivers"
|
"github.com/pobri19/sqlboiler/dbdrivers"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var rgxAutoIncColumn = regexp.MustCompile(`^nextval\(.*\)`)
|
||||||
|
|
||||||
// Plural converts singular words to plural words (eg: person to people)
|
// Plural converts singular words to plural words (eg: person to people)
|
||||||
func Plural(name string) string {
|
func Plural(name string) string {
|
||||||
splits := strings.Split(name, "_")
|
splits := strings.Split(name, "_")
|
||||||
|
@ -247,6 +250,18 @@ func PrimaryKeyFuncSig(cols []dbdrivers.Column, pkeyCols []string) string {
|
||||||
return strings.Join(output, ", ")
|
return strings.Join(output, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerateParamFlags generates the SQL statement parameter flags
|
||||||
|
// For example, $1,$2,$3 etc. It will start counting at startAt.
|
||||||
|
func GenerateParamFlags(colCount int, startAt int) string {
|
||||||
|
cols := make([]string, 0, colCount)
|
||||||
|
|
||||||
|
for i := startAt; i < colCount+startAt; i++ {
|
||||||
|
cols = append(cols, fmt.Sprintf("$%d", i))
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(cols, ",")
|
||||||
|
}
|
||||||
|
|
||||||
// WherePrimaryKey returns the where clause using start as the $ flag index
|
// WherePrimaryKey returns the where clause using start as the $ flag index
|
||||||
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
// For example, if start was 2 output would be: "colthing=$2 AND colstuff=$3"
|
||||||
func WherePrimaryKey(pkeyCols []string, start int) string {
|
func WherePrimaryKey(pkeyCols []string, start int) string {
|
||||||
|
@ -280,6 +295,26 @@ func PrimaryKeyStrList(pkeyCols []string) string {
|
||||||
return strings.Join(cols, ", ")
|
return strings.Join(cols, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AutoIncPrimKey returns the auto-increment primary key column name or an empty string
|
||||||
|
func AutoIncPrimaryKey(cols []dbdrivers.Column, pkey *dbdrivers.PrimaryKey) string {
|
||||||
|
if pkey == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, c := range cols {
|
||||||
|
if rgxAutoIncColumn.MatchString(c.Default) &&
|
||||||
|
c.IsNullable == false && c.Type == "int64" {
|
||||||
|
for _, p := range pkey.Columns {
|
||||||
|
if c.Name == p {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
// CommaList returns a comma seperated list: "col1, col2, col3"
|
// CommaList returns a comma seperated list: "col1, col2, col3"
|
||||||
func CommaList(cols []string) string {
|
func CommaList(cols []string) string {
|
||||||
return strings.Join(cols, ", ")
|
return strings.Join(cols, ", ")
|
||||||
|
@ -307,3 +342,32 @@ func ParamsPrimaryKey(prefix string, columns []string, shouldTitleCase bool) str
|
||||||
func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
func PrimaryKeyFlagIndex(regularCols []dbdrivers.Column, pkeyCols []string) int {
|
||||||
return len(regularCols) - len(pkeyCols) + 1
|
return len(regularCols) - len(pkeyCols) + 1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SupportsResult returns whether the database driver supports the sql.Results
|
||||||
|
// interface, i.e. LastReturnId and RowsAffected
|
||||||
|
func SupportsResultObject(driverName string) bool {
|
||||||
|
switch driverName {
|
||||||
|
case "postgres":
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// FilterColumnsByDefault generates the list of columns that have default values
|
||||||
|
func FilterColumnsByDefault(columns []dbdrivers.Column, defaults bool) string {
|
||||||
|
var cols []string
|
||||||
|
|
||||||
|
for _, c := range columns {
|
||||||
|
if (defaults && len(c.Default) != 0) || (!defaults && len(c.Default) == 0) {
|
||||||
|
cols = append(cols, fmt.Sprintf(`"%s"`, c.Name))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.Join(cols, `,`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DEFAULT WHITELIST: The things that are not default values. The things we want to insert all the time.
|
||||||
|
// WHITELIST: The things that we will NEVER return. The things that we will ALWAYS insert.
|
||||||
|
// DEFAULTS: The things that we will return (if not in WHITELIST)
|
||||||
|
// NON-ZEROS: The things that we will return (if not in WHITELIST)
|
||||||
|
|
|
@ -11,6 +11,73 @@ var testColumns = []dbdrivers.Column{
|
||||||
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
{Name: "enemy_column_thing", Type: "string", IsNullable: true},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestAutoIncPrimaryKey(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
var pkey *dbdrivers.PrimaryKey
|
||||||
|
var cols []dbdrivers.Column
|
||||||
|
|
||||||
|
r := AutoIncPrimaryKey(cols, pkey)
|
||||||
|
if r != "" {
|
||||||
|
t.Errorf("Expected empty string, got %s", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
pkey = &dbdrivers.PrimaryKey{
|
||||||
|
Columns: []string{
|
||||||
|
"col1", "auto",
|
||||||
|
},
|
||||||
|
Name: "",
|
||||||
|
}
|
||||||
|
|
||||||
|
cols = []dbdrivers.Column{
|
||||||
|
{
|
||||||
|
Name: "thing",
|
||||||
|
IsNullable: true,
|
||||||
|
Type: "int64",
|
||||||
|
Default: "nextval('abc'::regclass)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "stuff",
|
||||||
|
IsNullable: false,
|
||||||
|
Type: "string",
|
||||||
|
Default: "nextval('abc'::regclass)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "other",
|
||||||
|
IsNullable: false,
|
||||||
|
Type: "int64",
|
||||||
|
Default: "nextval",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
r = AutoIncPrimaryKey(cols, pkey)
|
||||||
|
if r != "" {
|
||||||
|
t.Errorf("Expected empty string, got %s", r)
|
||||||
|
}
|
||||||
|
|
||||||
|
cols = append(cols, dbdrivers.Column{
|
||||||
|
Name: "auto",
|
||||||
|
IsNullable: false,
|
||||||
|
Type: "int64",
|
||||||
|
Default: "nextval('abc'::regclass)",
|
||||||
|
})
|
||||||
|
|
||||||
|
r = AutoIncPrimaryKey(cols, pkey)
|
||||||
|
if r != "auto" {
|
||||||
|
t.Errorf("Expected empty string, got %s", r)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenerateParamFlags(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
x := GenerateParamFlags(5, 1)
|
||||||
|
want := "$1,$2,$3,$4,$5"
|
||||||
|
if want != x {
|
||||||
|
t.Errorf("want %s, got %s", want, x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestSingular(t *testing.T) {
|
func TestSingular(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -263,3 +330,41 @@ func TestWherePrimaryKey(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFilterColumnsByDefault(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
cols := []dbdrivers.Column{
|
||||||
|
{
|
||||||
|
Name: "col1",
|
||||||
|
Default: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "col2",
|
||||||
|
Default: "things",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "col3",
|
||||||
|
Default: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "col4",
|
||||||
|
Default: "things2",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res := FilterColumnsByDefault(cols, false)
|
||||||
|
if res != `"col1","col3"` {
|
||||||
|
t.Errorf("Invalid result: %s", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = FilterColumnsByDefault(cols, true)
|
||||||
|
if res != `"col2","col4"` {
|
||||||
|
t.Errorf("Invalid result: %s", res)
|
||||||
|
}
|
||||||
|
|
||||||
|
res = FilterColumnsByDefault([]dbdrivers.Column{}, false)
|
||||||
|
if res != `` {
|
||||||
|
t.Errorf("Invalid result: %s", res)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in a new issue