From 3152eed170cb3087f1bb59c20147ef1168ac89db Mon Sep 17 00:00:00 2001 From: Patrick O'brien Date: Fri, 3 Jun 2016 07:07:51 +1000 Subject: [PATCH] Added debug mode, test singles * Where statement OR/AND support * Added null-extended library to use different data types * Added disable triggers function for test main --- boil/db.go | 14 +- boil/qs/query_mods.go | 9 ++ boil/query.go | 34 ++++- boil/reflect.go | 83 ++++++++--- cmds/config.go | 57 +++++++- cmds/output.go | 42 ++++++ cmds/sqlboiler.go | 25 +++- cmds/templates/insert.tpl | 6 +- cmds/templates/update.tpl | 7 +- .../main_test/postgres_main.tpl | 129 +++++++++--------- cmds/templates_test/singles/helper_funcs.tpl | 53 +++++++ cmds/types.go | 6 +- dbdrivers/postgres_driver.go | 26 +++- 13 files changed, 386 insertions(+), 105 deletions(-) create mode 100644 cmds/templates_test/singles/helper_funcs.tpl diff --git a/boil/db.go b/boil/db.go index 3cefcfb..fc20bc7 100644 --- a/boil/db.go +++ b/boil/db.go @@ -1,6 +1,9 @@ package boil -import "database/sql" +import ( + "database/sql" + "os" +) type Executor interface { Exec(query string, args ...interface{}) (sql.Result, error) @@ -21,6 +24,15 @@ type Creator interface { var currentDB Executor +// DebugMode is a flag controlling whether generated sql statements and +// debug information is outputted to the DebugWriter handle +// +// NOTE: This should be disabled in production to avoid leaking sensitive data +var DebugMode = false + +// DebugWriter is where the debug output will be sent if DebugMode is true +var DebugWriter = os.Stdout + func Begin() (Transactor, error) { creator, ok := currentDB.(Creator) if !ok { diff --git a/boil/qs/query_mods.go b/boil/qs/query_mods.go index 4395663..7d25ef9 100644 --- a/boil/qs/query_mods.go +++ b/boil/qs/query_mods.go @@ -10,6 +10,15 @@ func Apply(q *boil.Query, mods ...QueryMod) { } } +func Or(whereMods ...QueryMod) QueryMod { + return func(q *boil.Query) { + if len(whereMods) < 2 { + // error, needs to be at least 2 for an or + } + // add the where mods to query with or seperators + } +} + func Limit(limit int) QueryMod { return func(q *boil.Query) { boil.SetLimit(q, limit) diff --git a/boil/query.go b/boil/query.go index d1a51e4..78bc3d0 100644 --- a/boil/query.go +++ b/boil/query.go @@ -8,8 +8,9 @@ import ( ) type where struct { - clause string - args []interface{} + clause string + orSeperator bool + args []interface{} } type join struct { @@ -65,36 +66,65 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { buf.WriteString(" FROM ") fmt.Fprintf(buf, `"%s"`, q.table) + buf.WriteByte(';') return buf, []interface{}{} } func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := &bytes.Buffer{} + buf.WriteString("DELETE FROM ") + fmt.Fprintf(buf, `"%s"`, q.table) + + if len(q.where) > 0 { + for i := 0; i < len(q.where); i++ { + buf.WriteString(fmt.Sprintf(` WHERE %s`, q.where[i].clause)) + if i != len(q.where)-1 { + if q.where[i].orSeperator { + buf.WriteString(` OR `) + } else { + buf.WriteString(` AND `) + } + } + } + } + + buf.WriteByte(';') + return buf, nil } func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { buf := &bytes.Buffer{} + buf.WriteByte(';') return buf, nil } // ExecQuery executes a query that does not need a row returned func ExecQuery(q *Query) (sql.Result, error) { qs, args := buildQuery(q) + if DebugMode { + fmt.Fprintln(DebugWriter, qs) + } return q.executor.Exec(qs, args...) } // ExecQueryOne executes the query for the One finisher and returns a row func ExecQueryOne(q *Query) *sql.Row { qs, args := buildQuery(q) + if DebugMode { + fmt.Fprintln(DebugWriter, qs) + } return q.executor.QueryRow(qs, args) } // ExecQueryAll executes the query for the All finisher and returns multiple rows func ExecQueryAll(q *Query) (*sql.Rows, error) { qs, args := buildQuery(q) + if DebugMode { + fmt.Fprintln(DebugWriter, qs) + } return q.executor.Query(qs, args) } diff --git a/boil/reflect.go b/boil/reflect.go index 5984cbf..47ed838 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -3,22 +3,33 @@ package boil import ( "database/sql" "fmt" + "math" "math/rand" "reflect" "sort" "time" - "github.com/guregu/null" "github.com/pobri19/sqlboiler/strmangle" + "gopkg.in/BlackBaronsTux/null-extended.v1" ) var ( - typeNullInt = reflect.TypeOf(null.Int{}) - typeNullFloat = reflect.TypeOf(null.Float{}) - typeNullString = reflect.TypeOf(null.String{}) - typeNullBool = reflect.TypeOf(null.Bool{}) - typeNullTime = reflect.TypeOf(null.Time{}) - typeTime = reflect.TypeOf(time.Time{}) + typeNullFloat32 = reflect.TypeOf(null.Float32{}) + typeNullFloat64 = reflect.TypeOf(null.Float64{}) + typeNullInt = reflect.TypeOf(null.Int{}) + typeNullInt8 = reflect.TypeOf(null.Int8{}) + typeNullInt16 = reflect.TypeOf(null.Int16{}) + typeNullInt32 = reflect.TypeOf(null.Int32{}) + typeNullInt64 = reflect.TypeOf(null.Int64{}) + typeNullUint = reflect.TypeOf(null.Uint{}) + typeNullUint8 = reflect.TypeOf(null.Uint8{}) + typeNullUint16 = reflect.TypeOf(null.Uint16{}) + typeNullUint32 = reflect.TypeOf(null.Uint32{}) + typeNullUint64 = reflect.TypeOf(null.Uint64{}) + typeNullString = reflect.TypeOf(null.String{}) + typeNullBool = reflect.TypeOf(null.Bool{}) + typeNullTime = reflect.TypeOf(null.Time{}) + typeTime = reflect.TypeOf(time.Time{}) ) // Bind executes the query and inserts the @@ -163,27 +174,65 @@ func randomizeField(field reflect.Value) error { if kind == reflect.Struct { switch typ { - case typeNullInt: - newVal = null.NewInt(rand.Int63(), rand.Intn(2) == 1) - case typeNullFloat: - newVal = null.NewFloat(rand.Float64(), rand.Intn(2) == 1) case typeNullBool: newVal = null.NewBool(rand.Intn(2) == 1, rand.Intn(2) == 1) case typeNullString: newVal = null.NewString(randStr(5+rand.Intn(25)), rand.Intn(2) == 1) case typeNullTime: - newVal = null.NewTime(time.Unix(rand.Int63(), 0), rand.Intn(2) == 1) + newVal = null.NewTime(time.Now().Add(time.Duration(rand.Intn((int(time.Hour * 24 * 10))))), rand.Intn(2) == 1) case typeTime: - newVal = time.Unix(rand.Int63(), 0) + newVal = time.Now().Add(time.Duration(rand.Intn((int(time.Hour * 24 * 10))))) + case typeNullFloat32: + newVal = null.NewFloat32(rand.Float32(), rand.Intn(2) == 1) + case typeNullFloat64: + newVal = null.NewFloat64(rand.Float64(), rand.Intn(2) == 1) + case typeNullInt: + newVal = null.NewInt(rand.Int(), rand.Intn(2) == 1) + case typeNullInt8: + newVal = null.NewInt8(int8(rand.Intn(int(math.MaxInt8))), rand.Intn(2) == 1) + case typeNullInt16: + newVal = null.NewInt16(int16(rand.Intn(int(math.MaxInt16))), rand.Intn(2) == 1) + case typeNullInt32: + newVal = null.NewInt32(rand.Int31(), rand.Intn(2) == 1) + case typeNullInt64: + newVal = null.NewInt64(rand.Int63(), rand.Intn(2) == 1) + case typeNullUint: + newVal = null.NewUint(uint(rand.Int()), rand.Intn(2) == 1) + case typeNullUint8: + newVal = null.NewUint8(uint8(rand.Intn(int(math.MaxInt8))), rand.Intn(2) == 1) + case typeNullUint16: + newVal = null.NewUint16(uint16(rand.Intn(int(math.MaxInt16))), rand.Intn(2) == 1) + case typeNullUint32: + newVal = null.NewUint32(uint32(rand.Int31()), rand.Intn(2) == 1) + case typeNullUint64: + newVal = null.NewUint64(uint64(rand.Int63()), rand.Intn(2) == 1) } } else { switch kind { - case reflect.Int: - newVal = rand.Int() - case reflect.Int64: - newVal = rand.Int63() + case reflect.Float32: + newVal = rand.Float32() case reflect.Float64: newVal = rand.Float64() + case reflect.Int: + newVal = rand.Int() + case reflect.Int8: + newVal = int8(rand.Intn(int(math.MaxInt8))) + case reflect.Int16: + newVal = int16(rand.Intn(int(math.MaxInt16))) + case reflect.Int32: + newVal = rand.Int31() + case reflect.Int64: + newVal = rand.Int63() + case reflect.Uint: + newVal = uint(rand.Int()) + case reflect.Uint8: + newVal = uint8(rand.Intn(int(math.MaxInt8))) + case reflect.Uint16: + newVal = uint16(rand.Intn(int(math.MaxInt16))) + case reflect.Uint32: + newVal = uint32(rand.Int31()) + case reflect.Uint64: + newVal = uint64(rand.Int63()) case reflect.Bool: var b bool if rand.Intn(2) == 1 { diff --git a/cmds/config.go b/cmds/config.go index af49a9c..c400bab 100644 --- a/cmds/config.go +++ b/cmds/config.go @@ -12,20 +12,50 @@ import ( // sqlBoilerTypeImports imports are only included in the template output if the database // requires one of the following special types. Check TranslateColumnType to see the type assignments. var sqlBoilerTypeImports = map[string]imports{ + "null.Float32": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Float64": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, "null.Int": imports{ - thirdparty: importList{`"gopkg.in/guregu/null.v3"`}, + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Int8": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Int16": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Int32": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Int64": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Uint": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Uint8": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Uint16": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Uint32": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, + }, + "null.Uint64": imports{ + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, }, "null.String": imports{ - thirdparty: importList{`"gopkg.in/guregu/null.v3"`}, + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, }, "null.Bool": imports{ - thirdparty: importList{`"gopkg.in/guregu/null.v3"`}, - }, - "null.Float": imports{ - thirdparty: importList{`"gopkg.in/guregu/null.v3"`}, + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, }, "null.Time": imports{ - thirdparty: importList{`"gopkg.in/guregu/null.v3"`}, + thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`}, }, "time.Time": imports{ standard: importList{`"time"`}, @@ -65,6 +95,19 @@ var sqlBoilerTestImports = imports{ }, } +var sqlBoilerSinglesTestImports = map[string]imports{ + "helper_funcs": imports{ + standard: importList{ + `"crypto/md5"`, + `"fmt"`, + `"os"`, + `"strconv"`, + `"math/rand"`, + }, + thirdparty: importList{}, + }, +} + var sqlBoilerTestMainImports = map[string]imports{ "postgres": imports{ standard: importList{ diff --git a/cmds/output.go b/cmds/output.go index c8b828c..edacd0d 100644 --- a/cmds/output.go +++ b/cmds/output.go @@ -111,6 +111,42 @@ func generateSinglesOutput(cmdData *CmdData) error { return nil } +func generateSinglesTestOutput(cmdData *CmdData) error { + if cmdData.SingleTestTemplates == nil { + return errors.New("No single test templates located for generation") + } + + tplData := &tplData{ + PkgName: cmdData.PkgName, + DriverName: cmdData.DriverName, + } + + for _, template := range cmdData.SingleTestTemplates { + var imps imports + + resp, err := generateTemplate(template, tplData) + if err != nil { + return fmt.Errorf("Error generating test template %s: %s", template.Name(), err) + } + + fName := template.Name() + ext := filepath.Ext(fName) + fName = fName[0 : len(fName)-len(ext)] + + imps.standard = sqlBoilerSinglesTestImports[fName].standard + imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty + + fName = fName + "_test.go" + + err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp}) + if err != nil { + return err + } + } + + return nil +} + func generateTestMainOutput(cmdData *CmdData) error { if cmdData.TestMainTemplate == nil { return errors.New("No TestMain template located for generation") @@ -122,9 +158,15 @@ func generateTestMainOutput(cmdData *CmdData) error { imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty + var tables []string + for _, v := range cmdData.Tables { + tables = append(tables, v.Name) + } + tplData := &tplData{ PkgName: cmdData.PkgName, DriverName: cmdData.DriverName, + Tables: tables, } resp, err := generateTemplate(cmdData.TestMainTemplate, tplData) diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index a95e2ff..3560bcd 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -14,10 +14,12 @@ import ( ) const ( - templatesDirectory = "/cmds/templates" - templatesSinglesDirectory = "/cmds/templates/singles" - templatesTestDirectory = "/cmds/templates_test" - templatesTestMainDirectory = "/cmds/templates_test/main_test" + templatesDirectory = "/cmds/templates" + templatesSinglesDirectory = "/cmds/templates/singles" + + templatesTestDirectory = "/cmds/templates_test" + templatesSinglesTestDirectory = "/cmds/templates_test/singles" + templatesTestMainDirectory = "/cmds/templates_test/main_test" ) // LoadTemplates loads all template folders into the cmdData object. @@ -38,6 +40,11 @@ func initTemplates(cmdData *CmdData) error { return err } + cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory) + if err != nil { + return err + } + filename := cmdData.DriverName + "_main.tpl" cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename) if err != nil { @@ -118,14 +125,18 @@ func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error { // run executes the sqlboiler templates and outputs them to files. func (c *CmdData) run(includeTests bool) error { + if err := generateSinglesOutput(c); err != nil { + return fmt.Errorf("Unable to generate single templates output: %s", err) + } + if includeTests { if err := generateTestMainOutput(c); err != nil { return fmt.Errorf("Unable to generate TestMain output: %s", err) } - } - if err := generateSinglesOutput(c); err != nil { - return fmt.Errorf("Unable to generate single templates output: %s", err) + if err := generateSinglesTestOutput(c); err != nil { + return fmt.Errorf("Unable to generate single test templates output: %s", err) + } } for _, table := range c.Tables { diff --git a/cmds/templates/insert.tpl b/cmds/templates/insert.tpl index 7f29c95..9c83390 100644 --- a/cmds/templates/insert.tpl +++ b/cmds/templates/insert.tpl @@ -63,12 +63,16 @@ func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... strin {{else}} if len(returnColumns) != 0 { ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ",")) - err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...)) + err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...)...) } else { _, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}}) } {{end}} + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, ins) + } + if err != nil { return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err) } diff --git a/cmds/templates/update.tpl b/cmds/templates/update.tpl index 75d2d28..72d5735 100644 --- a/cmds/templates/update.tpl +++ b/cmds/templates/update.tpl @@ -35,13 +35,18 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{primaryKeyFuncS } var err error + var query string if len(whitelist) != 0 { - query := fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}})) + query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}})) _, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{paramsPrimaryKey "o." .Table.PKey.Columns true}}) } else { return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err) } + if boil.DebugMode { + fmt.Fprintln(boil.DebugWriter, query) + } + if err != nil { return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err) } diff --git a/cmds/templates_test/main_test/postgres_main.tpl b/cmds/templates_test/main_test/postgres_main.tpl index 6650c53..aaae3e6 100644 --- a/cmds/templates_test/main_test/postgres_main.tpl +++ b/cmds/templates_test/main_test/postgres_main.tpl @@ -16,46 +16,81 @@ var testCfg *Config var dbConn *sql.DB func TestMain(m *testing.M) { - rand.Seed(time.Now().UnixNano()) + // Set the DebugMode to true so we can see generated sql statements + boil.DebugMode = true - err := setup() + rand.Seed(time.Now().UnixNano()) + var err error + + err = setup() if err != nil { - fmt.Println(err) + fmt.Printf("Unable to execute setup: %s", err) os.Exit(-1) } + err = disableTriggers() + if err != nil { + fmt.Printf("Unable to disable triggers: %s", err) + } boil.SetDB(dbConn) code := m.Run() err = teardown() if err != nil { - fmt.Println(err) + fmt.Printf("Unable to execute teardown: %s", err) os.Exit(-1) } os.Exit(code) } -// teardown switches its connection to the template1 database temporarily -// so that it can drop the test database and the test user. -// The template1 database should be present on all default postgres installations. +// disableTriggers is used to disable foreign key constraints for every table. +// If this is not used we cannot test inserts due to foreign key constraint errors. +func disableTriggers() error { + var stmts []string + + {{range .Tables}} + stmts = append(stmts, `ALTER TABLE {{.}} DISABLE TRIGGER ALL;`) + {{- end}} + + if len(stmts) == 0 { + return nil + } + + var err error + for _, s := range stmts { + _, err = dbConn.Exec(s) + if err != nil { + return err + } + } + + return nil +} + +// teardown executes cleanup tasks when the tests finish running func teardown() error { - err := dbConn.Close() + err := dropTestDB() + return err +} + +// dropTestDB switches its connection to the template1 database temporarily +// so that it can drop the test database without causing "in use" conflicts. +// The template1 database should be present on all default postgres installations. +func dropTestDB() error { + var err error + if dbConn != nil { + if err = dbConn.Close(); err != nil { + return err + } + } + + dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port) if err != nil { return err } - dbConn, err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, "template1", cfg.Postgres.Host, cfg.Postgres.Port) - if err != nil { - return err - } - - _, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE %s;`, testCfg.Postgres.DBName)) - if err != nil { - return err - } - - _, err = dbConn.Exec(fmt.Sprintf(`DROP USER %s;`, testCfg.Postgres.User)) + _, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName)) if err != nil { return err } @@ -71,16 +106,6 @@ func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) { return sql.Open("postgres", connStr) } -func randSeq(n int) string { - var letters = []rune("abcdefghijklmnopqrstuvwxyz") - - randStr := make([]rune, n) - for i := range randStr { - randStr[i] = letters[rand.Intn(len(letters))] - } - return string(randStr) -} - func LoadConfigFile(filename string) error { _, err := toml.DecodeFile(filename, &cfg) @@ -105,13 +130,21 @@ func setup() error { return fmt.Errorf("Unable to load config file: %s", err) } + testDBName := getDBNameHash(cfg.Postgres.DBName) + // Create a randomized test configuration object. testCfg = &Config{} testCfg.Postgres.Host = cfg.Postgres.Host testCfg.Postgres.Port = cfg.Postgres.Port - testCfg.Postgres.User = randSeq(20) - testCfg.Postgres.Pass = randSeq(20) - testCfg.Postgres.DBName = cfg.Postgres.DBName + "_" + randSeq(10) + testCfg.Postgres.User = cfg.Postgres.User + testCfg.Postgres.Pass = cfg.Postgres.Pass + testCfg.Postgres.DBName = testDBName + + err = dropTestDB() + if err != nil { + fmt.Printf("%#v\n", err) + return err + } fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema") if err != nil { @@ -166,35 +199,24 @@ func setup() error { return err } - // Create the randomly generated database test user - if err = createTestUser(dbConn); err != nil { - return err - } - // Create the randomly generated database _, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName)) if err != nil { return err } - // Assign the randomly generated db test user to the generated test db - _, err = dbConn.Exec(fmt.Sprintf(`ALTER DATABASE %s OWNER TO %s;`, testCfg.Postgres.DBName, testCfg.Postgres.User)) - if err != nil { - return err - } - - // Close the old connection so we can reconnect with the restricted access generated user + // Close the old connection so we can reconnect to the test database if err = dbConn.Close(); err != nil { return err } - // Connect to the generated test db with the restricted privilege generated user + // Connect to the generated test db dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port) if err != nil { return err } - // Write the generated user password to a tmp file for pg_dump + // Write the test config credentials to a tmp file for pg_dump testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s", testCfg.Postgres.Host, testCfg.Postgres.Port, @@ -233,18 +255,3 @@ func setup() error { return nil } - -// createTestUser creates a temporary database user with restricted privileges -func createTestUser(db *sql.DB) error { - now := time.Now().Add(time.Hour * 24 * 2) - valid := now.Format("2006-1-2") - - query := fmt.Sprintf(`CREATE USER %s WITH PASSWORD '%s' VALID UNTIL '%s';`, - testCfg.Postgres.User, - testCfg.Postgres.Pass, - valid, - ) - - _, err := dbConn.Exec(query) - return err -} diff --git a/cmds/templates_test/singles/helper_funcs.tpl b/cmds/templates_test/singles/helper_funcs.tpl new file mode 100644 index 0000000..e879913 --- /dev/null +++ b/cmds/templates_test/singles/helper_funcs.tpl @@ -0,0 +1,53 @@ +var dbNameRand *rand.Rand + +func initDBNameRand(input string) { + sum := md5.Sum([]byte(input)) + + var sumInt string + for _, v := range sum { + sumInt = sumInt + strconv.Itoa(int(v)) + } + + // Cut integer to 18 digits to ensure no int64 overflow. + sumInt = sumInt[:18] + + sumTmp := sumInt + for i, v := range sumInt { + if v == '0' { + sumTmp = sumInt[i+1:] + continue + } + + break + } + + sumInt = sumTmp + + randSeed, err := strconv.ParseInt(sumInt, 0, 64) + if err != nil { + fmt.Printf("Unable to parse sumInt: %s", err) + os.Exit(-1) + } + + dbNameRand = rand.New(rand.NewSource(randSeed)) +} + +var alphabetChars = "abcdefghijklmnopqrstuvwxyz" +func randStr(length int) string { + c := len(alphabetChars) + + output := make([]rune, length) + for i := 0; i < length; i++ { + output[i] = rune(alphabetChars[dbNameRand.Intn(c)]) + } + + return string(output) +} + +// getDBNameHash takes a database name in, and generates +// a random string using the database name as the rand Seed. +// getDBNameHash is used to generate unique test database names. +func getDBNameHash(input string) string { + initDBNameRand(input) + return randStr(40) +} diff --git a/cmds/types.go b/cmds/types.go index e1638ee..dbde66b 100644 --- a/cmds/types.go +++ b/cmds/types.go @@ -28,7 +28,10 @@ type CmdData struct { // SingleTemplates are only created once, not per table SingleTemplates templater - TestTemplates templater + TestTemplates templater + // SingleTestTemplates are only created once, not per table + SingleTestTemplates templater + //TestMainTemplate is only created once, not per table TestMainTemplate *template.Template } @@ -37,6 +40,7 @@ type tplData struct { Table dbdrivers.Table DriverName string PkgName string + Tables []string } type importList []string diff --git a/dbdrivers/postgres_driver.go b/dbdrivers/postgres_driver.go index 43017b8..90cbc74 100644 --- a/dbdrivers/postgres_driver.go +++ b/dbdrivers/postgres_driver.go @@ -210,23 +210,37 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error) func (p *PostgresDriver) TranslateColumnType(c Column) Column { if c.IsNullable { switch c.Type { - case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial": - c.Type = "null.Int" + case "bigint", "bigserial": + c.Type = "null.Int64" + case "integer", "serial": + c.Type = "null.Int32" + case "smallint", "smallserial": + c.Type = "null.Int16" + case "decimal", "numeric", "double precision", "money": + c.Type = "null.Float64" + case "real": + c.Type = "null.Float32" case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": c.Type = "null.String" case "boolean": c.Type = "null.Bool" case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone": c.Type = "null.Time" - case "double precision", "money", "numeric", "real": - c.Type = "null.Float" default: c.Type = "null.String" } } else { switch c.Type { - case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial": + case "bigint", "bigserial": c.Type = "int64" + case "integer", "serial": + c.Type = "int32" + case "smallint", "smallserial": + c.Type = "int16" + case "decimal", "numeric", "double precision", "money": + c.Type = "float64" + case "real": + c.Type = "float32" case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml": c.Type = "string" case "bytea": @@ -235,8 +249,6 @@ func (p *PostgresDriver) TranslateColumnType(c Column) Column { c.Type = "bool" case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone": c.Type = "time.Time" - case "double precision", "money", "numeric", "real": - c.Type = "float64" default: c.Type = "string" }