From c2541ea56ef76e1e87f7986b3a44016472690688 Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Tue, 17 May 2016 20:00:56 +1000 Subject: [PATCH] Begun implementing all tests * Added randomizeStruct * Added under development warning to readme * Restructured the reflection stuff a bit * Added a testmangle.go file for template test functions --- README.md | 2 + boil/helpers.go | 7 +- boil/reflect.go | 229 ++++++++++++++++++ boil/reflect_test.go | 159 ++++++++++++ cmds/config.go | 7 + cmds/templates/insert.tpl | 8 +- cmds/templates_test/all.tpl | 53 ++++ .../main_test/postgres_main.tpl | 2 + dbdrivers/interface.go | 2 +- dbdrivers/postgres_driver.go | 12 +- strmangle/testmangle.go | 22 ++ strmangle/testmangle_test.go | 1 + 12 files changed, 492 insertions(+), 12 deletions(-) create mode 100644 boil/reflect.go create mode 100644 boil/reflect_test.go create mode 100644 strmangle/testmangle.go create mode 100644 strmangle/testmangle_test.go diff --git a/README.md b/README.md index a6e25f0..29bd6c1 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ +# STILL IN DEVELOPMENT. ETA RELEASE: 1 MONTH + # SQLBoiler [![GoDoc](https://godoc.org/github.com/pobri19/sqlboiler?status.svg)](https://godoc.org/github.com/pobri19/sqlboiler) diff --git a/boil/helpers.go b/boil/helpers.go index c1103dd..076a19a 100644 --- a/boil/helpers.go +++ b/boil/helpers.go @@ -56,7 +56,7 @@ func SetIntersect(a []string, b []string) []string { func NonZeroDefaultSet(defaults []string, obj interface{}) []string { c := make([]string, 0, len(defaults)) - val := reflect.ValueOf(obj) + val := reflect.Indirect(reflect.ValueOf(obj)) for _, d := range defaults { fieldName := strmangle.TitleCase(d) @@ -128,10 +128,7 @@ func WherePrimaryKeyIn(numRows int, keyNames ...string) string { func SelectNames(results interface{}) string { var names []string - structValue := reflect.ValueOf(results) - if structValue.Kind() == reflect.Ptr { - structValue = structValue.Elem() - } + structValue := reflect.Indirect(reflect.ValueOf(results)) structType := structValue.Type() for i := 0; i < structValue.NumField(); i++ { diff --git a/boil/reflect.go b/boil/reflect.go new file mode 100644 index 0000000..5984cbf --- /dev/null +++ b/boil/reflect.go @@ -0,0 +1,229 @@ +package boil + +import ( + "database/sql" + "fmt" + "math/rand" + "reflect" + "sort" + "time" + + "github.com/guregu/null" + "github.com/pobri19/sqlboiler/strmangle" +) + +var ( + typeNullInt = reflect.TypeOf(null.Int{}) + typeNullFloat = reflect.TypeOf(null.Float{}) + typeNullString = reflect.TypeOf(null.String{}) + typeNullBool = reflect.TypeOf(null.Bool{}) + typeNullTime = reflect.TypeOf(null.Time{}) + typeTime = reflect.TypeOf(time.Time{}) +) + +// Bind executes the query and inserts the +// result into the passed in object pointer +func (q *Query) Bind(obj interface{}) error { + return nil +} + +// BindOne inserts the returned row columns into the +// passed in object pointer +func BindOne(row *sql.Row, obj interface{}) error { + return nil +} + +// BindAll inserts the returned rows columns into the +// passed in slice of object pointers +func BindAll(rows *sql.Rows, obj interface{}) error { + return nil +} + +func checkType(obj interface{}) (reflect.Type, bool, error) { + val := reflect.ValueOf(obj) + typ := val.Type() + kind := val.Kind() + + if kind != reflect.Ptr { + return nil, false, fmt.Errorf("Bind must be given pointers to structs but got type: %s, kind: %s", typ.String(), kind) + } + + typ = typ.Elem() + kind = typ.Kind() + isSlice := false + + switch kind { + case reflect.Slice: + typ = typ.Elem() + kind = typ.Kind() + isSlice = true + case reflect.Struct: + return typ, isSlice, nil + default: + return nil, false, fmt.Errorf("Bind was given an invalid object must be []*T or *T but got type: %s, kind: %s", typ.String(), kind) + } + + if kind != reflect.Ptr { + return nil, false, fmt.Errorf("Bind must be given pointers to structs but got type: %s, kind: %s", typ.String(), kind) + } + + typ = typ.Elem() + kind = typ.Kind() + + if kind != reflect.Struct { + return nil, false, fmt.Errorf("Bind must be a struct but got type: %s, kind: %s", typ.String(), kind) + } + + return typ, isSlice, nil +} + +// GetStructValues returns the values (as interface) of the matching columns in obj +func GetStructValues(obj interface{}, columns ...string) []interface{} { + ret := make([]interface{}, len(columns)) + val := reflect.Indirect(reflect.ValueOf(obj)) + + for i, c := range columns { + field := val.FieldByName(strmangle.TitleCase(c)) + ret[i] = field.Interface() + } + + return ret +} + +// GetStructPointers returns a slice of pointers to the matching columns in obj +func GetStructPointers(obj interface{}, columns ...string) []interface{} { + val := reflect.ValueOf(obj).Elem() + ret := make([]interface{}, len(columns)) + + for i, c := range columns { + field := val.FieldByName(strmangle.TitleCase(c)) + if !field.IsValid() { + panic(fmt.Sprintf("Could not find field on struct %T for field %s", obj, strmangle.TitleCase(c))) + } + + field = field.Addr() + ret[i] = field.Interface() + } + + return ret +} + +// RandomizeStruct takes an object and fills it with random data. +// It will ignore the fields in the blacklist. +func RandomizeStruct(str interface{}, blacklist ...string) error { + // Don't modify blacklist + copyBlacklist := make([]string, len(blacklist)) + copy(copyBlacklist, blacklist) + blacklist = copyBlacklist + + sort.Strings(blacklist) + + // Check if it's pointer + value := reflect.ValueOf(str) + kind := value.Kind() + if kind != reflect.Ptr { + return fmt.Errorf("can only randomize pointers to structs, given: %T", str) + } + + // Check if it's a struct + value = value.Elem() + kind = value.Kind() + if kind != reflect.Struct { + return fmt.Errorf("can only randomize pointers to structs, given: %T", str) + } + + typ := value.Type() + nFields := value.NumField() + + // Iterate through fields, randomizing + for i := 0; i < nFields; i++ { + fieldVal := value.Field(i) + fieldTyp := typ.Field(i) + + found := sort.Search(len(blacklist), func(i int) bool { + return blacklist[i] == fieldTyp.Name + }) + if found != len(blacklist) { + continue + } + + if err := randomizeField(fieldVal); err != nil { + return err + } + } + + return nil +} + +func randomizeField(field reflect.Value) error { + kind := field.Kind() + typ := field.Type() + + var newVal interface{} + + if kind == reflect.Struct { + switch typ { + case typeNullInt: + newVal = null.NewInt(rand.Int63(), rand.Intn(2) == 1) + case typeNullFloat: + newVal = null.NewFloat(rand.Float64(), rand.Intn(2) == 1) + case typeNullBool: + newVal = null.NewBool(rand.Intn(2) == 1, rand.Intn(2) == 1) + case typeNullString: + newVal = null.NewString(randStr(5+rand.Intn(25)), rand.Intn(2) == 1) + case typeNullTime: + newVal = null.NewTime(time.Unix(rand.Int63(), 0), rand.Intn(2) == 1) + case typeTime: + newVal = time.Unix(rand.Int63(), 0) + } + } else { + switch kind { + case reflect.Int: + newVal = rand.Int() + case reflect.Int64: + newVal = rand.Int63() + case reflect.Float64: + newVal = rand.Float64() + case reflect.Bool: + var b bool + if rand.Intn(2) == 1 { + b = true + } + newVal = b + case reflect.String: + newVal = randStr(5 + rand.Intn(20)) + case reflect.Slice: + sliceVal := typ.Elem() + if sliceVal.Kind() != reflect.Uint8 { + return fmt.Errorf("unsupported slice type: %T", typ.String()) + } + newVal = randByteSlice(5 + rand.Intn(20)) + default: + return fmt.Errorf("unsupported type: %T", typ.String()) + } + } + + field.Set(reflect.ValueOf(newVal)) + + return nil +} + +const alphabet = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + +func randStr(ln int) string { + str := make([]byte, ln) + for i := 0; i < ln; i++ { + str[i] = byte(alphabet[rand.Intn(len(alphabet))]) + } + + return string(str) +} + +func randByteSlice(ln int) []byte { + str := make([]byte, ln) + for i := 0; i < ln; i++ { + str[i] = byte(rand.Intn(256)) + } + + return str +} diff --git a/boil/reflect_test.go b/boil/reflect_test.go new file mode 100644 index 0000000..55346e8 --- /dev/null +++ b/boil/reflect_test.go @@ -0,0 +1,159 @@ +package boil + +import ( + "testing" + "time" + + "github.com/guregu/null" +) + +func TestGetStructValues(t *testing.T) { + t.Parallel() + timeThing := time.Now() + o := struct { + TitleThing string + Name string + ID int + Stuff int + Things int + Time time.Time + NullBool null.Bool + }{ + TitleThing: "patrick", + Stuff: 10, + Things: 0, + Time: timeThing, + NullBool: null.NewBool(true, false), + } + + vals := GetStructValues(&o, "title_thing", "name", "id", "stuff", "things", "time", "null_bool") + if vals[0].(string) != "patrick" { + t.Errorf("Want test, got %s", vals[0]) + } + if vals[1].(string) != "" { + t.Errorf("Want empty string, got %s", vals[1]) + } + if vals[2].(int) != 0 { + t.Errorf("Want 0, got %d", vals[2]) + } + if vals[3].(int) != 10 { + t.Errorf("Want 10, got %d", vals[3]) + } + if vals[4].(int) != 0 { + t.Errorf("Want 0, got %d", vals[4]) + } + if !vals[5].(time.Time).Equal(timeThing) { + t.Errorf("Want %s, got %s", o.Time, vals[5]) + } + if !vals[6].(null.Bool).IsZero() { + t.Errorf("Want %v, got %v", o.NullBool, vals[6]) + } +} + +func TestGetStructPointers(t *testing.T) { + t.Parallel() + + o := struct { + Title string + ID *int + }{ + Title: "patrick", + } + + ptrs := GetStructPointers(&o, "title", "id") + *ptrs[0].(*string) = "test" + if o.Title != "test" { + t.Errorf("Expected test, got %s", o.Title) + } + x := 5 + *ptrs[1].(**int) = &x + if *o.ID != 5 { + t.Errorf("Expected 5, got %d", *o.ID) + } +} + +func TestCheckType(t *testing.T) { + t.Parallel() + + type Thing struct { + } + + validTest := []struct { + Input interface{} + IsSlice bool + TypeName string + }{ + {&[]*Thing{}, true, "boil.Thing"}, + {[]Thing{}, false, ""}, + {&[]Thing{}, false, ""}, + {Thing{}, false, ""}, + {new(int), false, ""}, + {5, false, ""}, + {&Thing{}, false, "boil.Thing"}, + } + + for i, test := range validTest { + typ, isSlice, err := checkType(test.Input) + if err != nil { + if len(test.TypeName) > 0 { + t.Errorf("%d) Type: %T %#v - should have succeded but got err: %v", i, test.Input, test.Input, err) + } + continue + } + + if isSlice != test.IsSlice { + t.Errorf("%d) Type: %T %#v - succeded but wrong isSlice value: %t, want %t", i, test.Input, test.Input, isSlice, test.IsSlice) + } + + if got := typ.String(); got != test.TypeName { + t.Errorf("%d) Type: %T %#v - succeded but wrong type name: %s, want: %s", i, test.Input, test.Input, got, test.TypeName) + } + } +} + +func TestRandomizeStruct(t *testing.T) { + var testStruct = struct { + Int int + Int64 int64 + Float64 float64 + Bool bool + Time time.Time + String string + ByteSlice []byte + + Ignore int + + NullInt null.Int + NullFloat null.Float + NullBool null.Bool + NullString null.String + NullTime null.Time + }{} + + err := RandomizeStruct(&testStruct, "Ignore") + if err != nil { + t.Fatal(err) + } + + if testStruct.Ignore != 0 { + t.Error("blacklisted value was filled in:", testStruct.Ignore) + } + + if testStruct.Int == 0 && + testStruct.Int64 == 0 && + testStruct.Float64 == 0 && + testStruct.Bool == false && + testStruct.Time.IsZero() && + testStruct.String == "" && + testStruct.ByteSlice == nil { + t.Errorf("the regular values are not being randomized: %#v", testStruct) + } + + if testStruct.NullInt.Valid == false && + testStruct.NullFloat.Valid == false && + testStruct.NullBool.Valid == false && + testStruct.NullString.Valid == false && + testStruct.NullTime.Valid == false { + t.Errorf("the null values are not being randomized: %#v", testStruct) + } +} diff --git a/cmds/config.go b/cmds/config.go index 55a5f37..af49a9c 100644 --- a/cmds/config.go +++ b/cmds/config.go @@ -60,6 +60,9 @@ var sqlBoilerTestImports = imports{ standard: importList{ `"testing"`, }, + thirdparty: importList{ + `"github.com/pobri19/sqlboiler/boil"`, + }, } var sqlBoilerTestMainImports = map[string]imports{ @@ -76,6 +79,7 @@ var sqlBoilerTestMainImports = map[string]imports{ `"math/rand"`, }, thirdparty: importList{ + `"github.com/pobri19/sqlboiler/boil"`, `"github.com/BurntSushi/toml"`, `_ "github.com/lib/pq"`, }, @@ -114,6 +118,9 @@ var sqlBoilerTemplateFuncs = template.FuncMap{ "filterColumnsByDefault": strmangle.FilterColumnsByDefault, "filterColumnsByAutoIncrement": strmangle.FilterColumnsByAutoIncrement, "autoIncPrimaryKey": strmangle.AutoIncPrimaryKey, + + "randDBStruct": strmangle.RandDBStruct, + "randDBStructSlice": strmangle.RandDBStructSlice, } // LoadConfigFile loads the toml config file into the cfg object diff --git a/cmds/templates/insert.tpl b/cmds/templates/insert.tpl index 237683a..7f29c95 100644 --- a/cmds/templates/insert.tpl +++ b/cmds/templates/insert.tpl @@ -32,7 +32,7 @@ func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... strin {{if supportsResultObject .DriverName}} if len(returnColumns) != 0 { - result, err := exec.Exec(ins, boil.GetStructValues(o, wl...)) + result, err := exec.Exec(ins, boil.GetStructValues(o, wl...)...) if err != nil { return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err) } @@ -40,7 +40,7 @@ func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... strin lastId, err := result.lastInsertId() if err != nil || lastId == 0 { sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s`, strings.Join(returnColumns, ","), boil.WhereClause(wl)) - rows, err := exec.Query(sel, boil.GetStructValues(o, wl...)) + rows, err := exec.Query(sel, boil.GetStructValues(o, wl...)...) if err != nil { return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err) } @@ -58,12 +58,12 @@ func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... strin sel := fmt.Sprintf(`SELECT %s FROM {{.Table.Name}} WHERE %s=$1`, strings.Join(returnColumns, ","), {{$varNameSingular}}AutoIncPrimaryKey, lastId) } } else { - _, err = exec.Exec(ins, boil.GetStructValues(o, wl...)) + _, err = exec.Exec(ins, boil.GetStructValues(o, wl...)...) } {{else}} if len(returnColumns) != 0 { ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) - err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)).Scan(boil.GetStructPointers(o, returnColumns...)) + err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...)) } else { _, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}}) } diff --git a/cmds/templates_test/all.tpl b/cmds/templates_test/all.tpl index 0a3d1eb..e4b7c45 100644 --- a/cmds/templates_test/all.tpl +++ b/cmds/templates_test/all.tpl @@ -4,5 +4,58 @@ {{- $varNamePlural := camelCasePlural .Table.Name -}} // {{$tableNamePlural}}All retrieves all records. func Test{{$tableNamePlural}}All(t *testing.T) { + var err error + r := make([]{{$tableNameSingular}}, 2) + + // insert two random columns to test DeleteAll + for i, v := range r { + err = boil.RandomizeStruct(&v) + if err != nil { + t.Errorf("%d: Unable to randomize {{$tableNameSingular}} struct: %s", i, err) + } + + err = v.Insert() + if err != nil { + t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", v, err) + } + } + + // Delete all rows to give a clean slate + err = {{$tableNamePlural}}().DeleteAll() + if err != nil { + t.Errorf("Unable to delete all from {{$tableNamePlural}}: %s", err) + } + + // Check number of rows in table to ensure DeleteAll was successful + var c int64 + c, err = {{$tableNamePlural}}().Count() + + if c != 0 { + t.Errorf("Expected {{.Table.Name}} table to be empty, but got %d rows", c) + } + + o := make([]{{$tableNameSingular}}, 3) + + for i, v := range o { + err = boil.RandomizeStruct(&v) + if err != nil { + t.Errorf("%d: Unable to randomize {{$tableNameSingular}} struct: %s", i, err) + } + + err = v.Insert() + if err != nil { + t.Errorf("Unable to insert {{$tableNameSingular}}:\n%#v\nErr: %s", v, err) + } + } + + // Attempt to retrieve all objects + res, err := {{$tableNamePlural}}().All() + if err != nil { + t.Errorf("Unable to retrieve all {{$tableNamePlural}}, err: %s", err) + } + + if len(res) != 3 { + t.Errorf("Expected 3 {{$tableNameSingular}} rows, got %d", len(res)) + } } diff --git a/cmds/templates_test/main_test/postgres_main.tpl b/cmds/templates_test/main_test/postgres_main.tpl index 8d056ce..6650c53 100644 --- a/cmds/templates_test/main_test/postgres_main.tpl +++ b/cmds/templates_test/main_test/postgres_main.tpl @@ -24,6 +24,7 @@ func TestMain(m *testing.M) { os.Exit(-1) } + boil.SetDB(dbConn) code := m.Run() err = teardown() @@ -37,6 +38,7 @@ func TestMain(m *testing.M) { // teardown switches its connection to the template1 database temporarily // so that it can drop the test database and the test user. +// The template1 database should be present on all default postgres installations. func teardown() error { err := dbConn.Close() if err != nil { diff --git a/dbdrivers/interface.go b/dbdrivers/interface.go index 9ddbeef..50f9653 100644 --- a/dbdrivers/interface.go +++ b/dbdrivers/interface.go @@ -71,8 +71,8 @@ func Tables(db Interface, names ...string) ([]Table, error) { t := Table{Name: name} if t.Columns, err = db.Columns(name); err != nil { + fmt.Println("Unable to get columns.") return nil, err - fmt.Println("Unable to get columnss.") } for i, c := range t.Columns { diff --git a/dbdrivers/postgres_driver.go b/dbdrivers/postgres_driver.go index 9624669..43017b8 100644 --- a/dbdrivers/postgres_driver.go +++ b/dbdrivers/postgres_driver.go @@ -91,9 +91,17 @@ func (p *PostgresDriver) Columns(tableName string) ([]Column, error) { defer rows.Close() for rows.Next() { var colName, colType, colDefault, isNullable string - if err := rows.Scan(&colName, &colType, &colDefault, &isNullable); err != nil { - return nil, err + var defaultPtr *string + if err := rows.Scan(&colName, &colType, &defaultPtr, &isNullable); err != nil { + return nil, fmt.Errorf("Unable to scan for table %s: %s", tableName, err) } + + if defaultPtr == nil { + colDefault = "" + } else { + colDefault = *defaultPtr + } + column := Column{ Name: colName, Type: colType, diff --git a/strmangle/testmangle.go b/strmangle/testmangle.go new file mode 100644 index 0000000..af70ec7 --- /dev/null +++ b/strmangle/testmangle.go @@ -0,0 +1,22 @@ +package strmangle + +import ( + "fmt" + "strings" + + "github.com/pobri19/sqlboiler/dbdrivers" +) + +func RandDBStruct(varName string, table dbdrivers.Table) string { + return "" +} + +func RandDBStructSlice(varName string, num int, table dbdrivers.Table) string { + var structs []string + for i := 0; i < num; i++ { + structs = append(structs, RandDBStruct(varName, table)) + } + + innerStructs := strings.Join(structs, ",") + return fmt.Sprintf("%s := %s{%s}", varName, TitleCasePlural(table.Name), innerStructs) +} diff --git a/strmangle/testmangle_test.go b/strmangle/testmangle_test.go new file mode 100644 index 0000000..2d1fc76 --- /dev/null +++ b/strmangle/testmangle_test.go @@ -0,0 +1 @@ +package strmangle