Added debug mode, test singles
* Where statement OR/AND support * Added null-extended library to use different data types * Added disable triggers function for test main
This commit is contained in:
parent
c2541ea56e
commit
3152eed170
13 changed files with 386 additions and 105 deletions
14
boil/db.go
14
boil/db.go
|
@ -1,6 +1,9 @@
|
|||
package boil
|
||||
|
||||
import "database/sql"
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
)
|
||||
|
||||
type Executor interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
|
@ -21,6 +24,15 @@ type Creator interface {
|
|||
|
||||
var currentDB Executor
|
||||
|
||||
// DebugMode is a flag controlling whether generated sql statements and
|
||||
// debug information is outputted to the DebugWriter handle
|
||||
//
|
||||
// NOTE: This should be disabled in production to avoid leaking sensitive data
|
||||
var DebugMode = false
|
||||
|
||||
// DebugWriter is where the debug output will be sent if DebugMode is true
|
||||
var DebugWriter = os.Stdout
|
||||
|
||||
func Begin() (Transactor, error) {
|
||||
creator, ok := currentDB.(Creator)
|
||||
if !ok {
|
||||
|
|
|
@ -10,6 +10,15 @@ func Apply(q *boil.Query, mods ...QueryMod) {
|
|||
}
|
||||
}
|
||||
|
||||
func Or(whereMods ...QueryMod) QueryMod {
|
||||
return func(q *boil.Query) {
|
||||
if len(whereMods) < 2 {
|
||||
// error, needs to be at least 2 for an or
|
||||
}
|
||||
// add the where mods to query with or seperators
|
||||
}
|
||||
}
|
||||
|
||||
func Limit(limit int) QueryMod {
|
||||
return func(q *boil.Query) {
|
||||
boil.SetLimit(q, limit)
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
|
||||
type where struct {
|
||||
clause string
|
||||
orSeperator bool
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
|
@ -65,36 +66,65 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|||
buf.WriteString(" FROM ")
|
||||
fmt.Fprintf(buf, `"%s"`, q.table)
|
||||
|
||||
buf.WriteByte(';')
|
||||
return buf, []interface{}{}
|
||||
}
|
||||
|
||||
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteString("DELETE FROM ")
|
||||
fmt.Fprintf(buf, `"%s"`, q.table)
|
||||
|
||||
if len(q.where) > 0 {
|
||||
for i := 0; i < len(q.where); i++ {
|
||||
buf.WriteString(fmt.Sprintf(` WHERE %s`, q.where[i].clause))
|
||||
if i != len(q.where)-1 {
|
||||
if q.where[i].orSeperator {
|
||||
buf.WriteString(` OR `)
|
||||
} else {
|
||||
buf.WriteString(` AND `)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
buf.WriteByte(';')
|
||||
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
buf.WriteByte(';')
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
// ExecQuery executes a query that does not need a row returned
|
||||
func ExecQuery(q *Query) (sql.Result, error) {
|
||||
qs, args := buildQuery(q)
|
||||
if DebugMode {
|
||||
fmt.Fprintln(DebugWriter, qs)
|
||||
}
|
||||
return q.executor.Exec(qs, args...)
|
||||
}
|
||||
|
||||
// ExecQueryOne executes the query for the One finisher and returns a row
|
||||
func ExecQueryOne(q *Query) *sql.Row {
|
||||
qs, args := buildQuery(q)
|
||||
if DebugMode {
|
||||
fmt.Fprintln(DebugWriter, qs)
|
||||
}
|
||||
return q.executor.QueryRow(qs, args)
|
||||
}
|
||||
|
||||
// ExecQueryAll executes the query for the All finisher and returns multiple rows
|
||||
func ExecQueryAll(q *Query) (*sql.Rows, error) {
|
||||
qs, args := buildQuery(q)
|
||||
if DebugMode {
|
||||
fmt.Fprintln(DebugWriter, qs)
|
||||
}
|
||||
return q.executor.Query(qs, args)
|
||||
}
|
||||
|
||||
|
|
|
@ -3,18 +3,29 @@ package boil
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"reflect"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/guregu/null"
|
||||
"github.com/pobri19/sqlboiler/strmangle"
|
||||
"gopkg.in/BlackBaronsTux/null-extended.v1"
|
||||
)
|
||||
|
||||
var (
|
||||
typeNullFloat32 = reflect.TypeOf(null.Float32{})
|
||||
typeNullFloat64 = reflect.TypeOf(null.Float64{})
|
||||
typeNullInt = reflect.TypeOf(null.Int{})
|
||||
typeNullFloat = reflect.TypeOf(null.Float{})
|
||||
typeNullInt8 = reflect.TypeOf(null.Int8{})
|
||||
typeNullInt16 = reflect.TypeOf(null.Int16{})
|
||||
typeNullInt32 = reflect.TypeOf(null.Int32{})
|
||||
typeNullInt64 = reflect.TypeOf(null.Int64{})
|
||||
typeNullUint = reflect.TypeOf(null.Uint{})
|
||||
typeNullUint8 = reflect.TypeOf(null.Uint8{})
|
||||
typeNullUint16 = reflect.TypeOf(null.Uint16{})
|
||||
typeNullUint32 = reflect.TypeOf(null.Uint32{})
|
||||
typeNullUint64 = reflect.TypeOf(null.Uint64{})
|
||||
typeNullString = reflect.TypeOf(null.String{})
|
||||
typeNullBool = reflect.TypeOf(null.Bool{})
|
||||
typeNullTime = reflect.TypeOf(null.Time{})
|
||||
|
@ -163,27 +174,65 @@ func randomizeField(field reflect.Value) error {
|
|||
|
||||
if kind == reflect.Struct {
|
||||
switch typ {
|
||||
case typeNullInt:
|
||||
newVal = null.NewInt(rand.Int63(), rand.Intn(2) == 1)
|
||||
case typeNullFloat:
|
||||
newVal = null.NewFloat(rand.Float64(), rand.Intn(2) == 1)
|
||||
case typeNullBool:
|
||||
newVal = null.NewBool(rand.Intn(2) == 1, rand.Intn(2) == 1)
|
||||
case typeNullString:
|
||||
newVal = null.NewString(randStr(5+rand.Intn(25)), rand.Intn(2) == 1)
|
||||
case typeNullTime:
|
||||
newVal = null.NewTime(time.Unix(rand.Int63(), 0), rand.Intn(2) == 1)
|
||||
newVal = null.NewTime(time.Now().Add(time.Duration(rand.Intn((int(time.Hour * 24 * 10))))), rand.Intn(2) == 1)
|
||||
case typeTime:
|
||||
newVal = time.Unix(rand.Int63(), 0)
|
||||
newVal = time.Now().Add(time.Duration(rand.Intn((int(time.Hour * 24 * 10)))))
|
||||
case typeNullFloat32:
|
||||
newVal = null.NewFloat32(rand.Float32(), rand.Intn(2) == 1)
|
||||
case typeNullFloat64:
|
||||
newVal = null.NewFloat64(rand.Float64(), rand.Intn(2) == 1)
|
||||
case typeNullInt:
|
||||
newVal = null.NewInt(rand.Int(), rand.Intn(2) == 1)
|
||||
case typeNullInt8:
|
||||
newVal = null.NewInt8(int8(rand.Intn(int(math.MaxInt8))), rand.Intn(2) == 1)
|
||||
case typeNullInt16:
|
||||
newVal = null.NewInt16(int16(rand.Intn(int(math.MaxInt16))), rand.Intn(2) == 1)
|
||||
case typeNullInt32:
|
||||
newVal = null.NewInt32(rand.Int31(), rand.Intn(2) == 1)
|
||||
case typeNullInt64:
|
||||
newVal = null.NewInt64(rand.Int63(), rand.Intn(2) == 1)
|
||||
case typeNullUint:
|
||||
newVal = null.NewUint(uint(rand.Int()), rand.Intn(2) == 1)
|
||||
case typeNullUint8:
|
||||
newVal = null.NewUint8(uint8(rand.Intn(int(math.MaxInt8))), rand.Intn(2) == 1)
|
||||
case typeNullUint16:
|
||||
newVal = null.NewUint16(uint16(rand.Intn(int(math.MaxInt16))), rand.Intn(2) == 1)
|
||||
case typeNullUint32:
|
||||
newVal = null.NewUint32(uint32(rand.Int31()), rand.Intn(2) == 1)
|
||||
case typeNullUint64:
|
||||
newVal = null.NewUint64(uint64(rand.Int63()), rand.Intn(2) == 1)
|
||||
}
|
||||
} else {
|
||||
switch kind {
|
||||
case reflect.Int:
|
||||
newVal = rand.Int()
|
||||
case reflect.Int64:
|
||||
newVal = rand.Int63()
|
||||
case reflect.Float32:
|
||||
newVal = rand.Float32()
|
||||
case reflect.Float64:
|
||||
newVal = rand.Float64()
|
||||
case reflect.Int:
|
||||
newVal = rand.Int()
|
||||
case reflect.Int8:
|
||||
newVal = int8(rand.Intn(int(math.MaxInt8)))
|
||||
case reflect.Int16:
|
||||
newVal = int16(rand.Intn(int(math.MaxInt16)))
|
||||
case reflect.Int32:
|
||||
newVal = rand.Int31()
|
||||
case reflect.Int64:
|
||||
newVal = rand.Int63()
|
||||
case reflect.Uint:
|
||||
newVal = uint(rand.Int())
|
||||
case reflect.Uint8:
|
||||
newVal = uint8(rand.Intn(int(math.MaxInt8)))
|
||||
case reflect.Uint16:
|
||||
newVal = uint16(rand.Intn(int(math.MaxInt16)))
|
||||
case reflect.Uint32:
|
||||
newVal = uint32(rand.Int31())
|
||||
case reflect.Uint64:
|
||||
newVal = uint64(rand.Int63())
|
||||
case reflect.Bool:
|
||||
var b bool
|
||||
if rand.Intn(2) == 1 {
|
||||
|
|
|
@ -12,20 +12,50 @@ import (
|
|||
// sqlBoilerTypeImports imports are only included in the template output if the database
|
||||
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
|
||||
var sqlBoilerTypeImports = map[string]imports{
|
||||
"null.Float32": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Float64": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Int": imports{
|
||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Int8": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Int16": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Int32": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Int64": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Uint": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Uint8": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Uint16": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Uint32": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Uint64": imports{
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.String": imports{
|
||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Bool": imports{
|
||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||
},
|
||||
"null.Float": imports{
|
||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"null.Time": imports{
|
||||
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
|
||||
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
|
||||
},
|
||||
"time.Time": imports{
|
||||
standard: importList{`"time"`},
|
||||
|
@ -65,6 +95,19 @@ var sqlBoilerTestImports = imports{
|
|||
},
|
||||
}
|
||||
|
||||
var sqlBoilerSinglesTestImports = map[string]imports{
|
||||
"helper_funcs": imports{
|
||||
standard: importList{
|
||||
`"crypto/md5"`,
|
||||
`"fmt"`,
|
||||
`"os"`,
|
||||
`"strconv"`,
|
||||
`"math/rand"`,
|
||||
},
|
||||
thirdparty: importList{},
|
||||
},
|
||||
}
|
||||
|
||||
var sqlBoilerTestMainImports = map[string]imports{
|
||||
"postgres": imports{
|
||||
standard: importList{
|
||||
|
|
|
@ -111,6 +111,42 @@ func generateSinglesOutput(cmdData *CmdData) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func generateSinglesTestOutput(cmdData *CmdData) error {
|
||||
if cmdData.SingleTestTemplates == nil {
|
||||
return errors.New("No single test templates located for generation")
|
||||
}
|
||||
|
||||
tplData := &tplData{
|
||||
PkgName: cmdData.PkgName,
|
||||
DriverName: cmdData.DriverName,
|
||||
}
|
||||
|
||||
for _, template := range cmdData.SingleTestTemplates {
|
||||
var imps imports
|
||||
|
||||
resp, err := generateTemplate(template, tplData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
|
||||
}
|
||||
|
||||
fName := template.Name()
|
||||
ext := filepath.Ext(fName)
|
||||
fName = fName[0 : len(fName)-len(ext)]
|
||||
|
||||
imps.standard = sqlBoilerSinglesTestImports[fName].standard
|
||||
imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty
|
||||
|
||||
fName = fName + "_test.go"
|
||||
|
||||
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func generateTestMainOutput(cmdData *CmdData) error {
|
||||
if cmdData.TestMainTemplate == nil {
|
||||
return errors.New("No TestMain template located for generation")
|
||||
|
@ -122,9 +158,15 @@ func generateTestMainOutput(cmdData *CmdData) error {
|
|||
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
|
||||
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
|
||||
|
||||
var tables []string
|
||||
for _, v := range cmdData.Tables {
|
||||
tables = append(tables, v.Name)
|
||||
}
|
||||
|
||||
tplData := &tplData{
|
||||
PkgName: cmdData.PkgName,
|
||||
DriverName: cmdData.DriverName,
|
||||
Tables: tables,
|
||||
}
|
||||
|
||||
resp, err := generateTemplate(cmdData.TestMainTemplate, tplData)
|
||||
|
|
|
@ -16,7 +16,9 @@ import (
|
|||
const (
|
||||
templatesDirectory = "/cmds/templates"
|
||||
templatesSinglesDirectory = "/cmds/templates/singles"
|
||||
|
||||
templatesTestDirectory = "/cmds/templates_test"
|
||||
templatesSinglesTestDirectory = "/cmds/templates_test/singles"
|
||||
templatesTestMainDirectory = "/cmds/templates_test/main_test"
|
||||
)
|
||||
|
||||
|
@ -38,6 +40,11 @@ func initTemplates(cmdData *CmdData) error {
|
|||
return err
|
||||
}
|
||||
|
||||
cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
filename := cmdData.DriverName + "_main.tpl"
|
||||
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
|
||||
if err != nil {
|
||||
|
@ -118,14 +125,18 @@ func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
|
|||
|
||||
// run executes the sqlboiler templates and outputs them to files.
|
||||
func (c *CmdData) run(includeTests bool) error {
|
||||
if err := generateSinglesOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate single templates output: %s", err)
|
||||
}
|
||||
|
||||
if includeTests {
|
||||
if err := generateTestMainOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate TestMain output: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := generateSinglesOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate single templates output: %s", err)
|
||||
if err := generateSinglesTestOutput(c); err != nil {
|
||||
return fmt.Errorf("Unable to generate single test templates output: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
for _, table := range c.Tables {
|
||||
|
|
|
@ -63,12 +63,16 @@ func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... strin
|
|||
{{else}}
|
||||
if len(returnColumns) != 0 {
|
||||
ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
|
||||
err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...))
|
||||
err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...)...)
|
||||
} else {
|
||||
_, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}})
|
||||
}
|
||||
{{end}}
|
||||
|
||||
if boil.DebugMode {
|
||||
fmt.Fprintln(boil.DebugWriter, ins)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
|
||||
}
|
||||
|
|
|
@ -35,13 +35,18 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{primaryKeyFuncS
|
|||
}
|
||||
|
||||
var err error
|
||||
var query string
|
||||
if len(whitelist) != 0 {
|
||||
query := fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}}))
|
||||
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}}))
|
||||
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
|
||||
} else {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err)
|
||||
}
|
||||
|
||||
if boil.DebugMode {
|
||||
fmt.Fprintln(boil.DebugWriter, query)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
|
||||
}
|
||||
|
|
|
@ -16,46 +16,81 @@ var testCfg *Config
|
|||
var dbConn *sql.DB
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
// Set the DebugMode to true so we can see generated sql statements
|
||||
boil.DebugMode = true
|
||||
|
||||
err := setup()
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
var err error
|
||||
|
||||
err = setup()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Unable to execute setup: %s", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
err = disableTriggers()
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to disable triggers: %s", err)
|
||||
}
|
||||
boil.SetDB(dbConn)
|
||||
code := m.Run()
|
||||
|
||||
err = teardown()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
fmt.Printf("Unable to execute teardown: %s", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// teardown switches its connection to the template1 database temporarily
|
||||
// so that it can drop the test database and the test user.
|
||||
// The template1 database should be present on all default postgres installations.
|
||||
// disableTriggers is used to disable foreign key constraints for every table.
|
||||
// If this is not used we cannot test inserts due to foreign key constraint errors.
|
||||
func disableTriggers() error {
|
||||
var stmts []string
|
||||
|
||||
{{range .Tables}}
|
||||
stmts = append(stmts, `ALTER TABLE {{.}} DISABLE TRIGGER ALL;`)
|
||||
{{- end}}
|
||||
|
||||
if len(stmts) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, s := range stmts {
|
||||
_, err = dbConn.Exec(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// teardown executes cleanup tasks when the tests finish running
|
||||
func teardown() error {
|
||||
err := dbConn.Close()
|
||||
err := dropTestDB()
|
||||
return err
|
||||
}
|
||||
|
||||
// dropTestDB switches its connection to the template1 database temporarily
|
||||
// so that it can drop the test database without causing "in use" conflicts.
|
||||
// The template1 database should be present on all default postgres installations.
|
||||
func dropTestDB() error {
|
||||
var err error
|
||||
if dbConn != nil {
|
||||
if err = dbConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
dbConn, err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, "template1", cfg.Postgres.Host, cfg.Postgres.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE %s;`, testCfg.Postgres.DBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`DROP USER %s;`, testCfg.Postgres.User))
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -71,16 +106,6 @@ func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
|
|||
return sql.Open("postgres", connStr)
|
||||
}
|
||||
|
||||
func randSeq(n int) string {
|
||||
var letters = []rune("abcdefghijklmnopqrstuvwxyz")
|
||||
|
||||
randStr := make([]rune, n)
|
||||
for i := range randStr {
|
||||
randStr[i] = letters[rand.Intn(len(letters))]
|
||||
}
|
||||
return string(randStr)
|
||||
}
|
||||
|
||||
func LoadConfigFile(filename string) error {
|
||||
_, err := toml.DecodeFile(filename, &cfg)
|
||||
|
||||
|
@ -105,13 +130,21 @@ func setup() error {
|
|||
return fmt.Errorf("Unable to load config file: %s", err)
|
||||
}
|
||||
|
||||
testDBName := getDBNameHash(cfg.Postgres.DBName)
|
||||
|
||||
// Create a randomized test configuration object.
|
||||
testCfg = &Config{}
|
||||
testCfg.Postgres.Host = cfg.Postgres.Host
|
||||
testCfg.Postgres.Port = cfg.Postgres.Port
|
||||
testCfg.Postgres.User = randSeq(20)
|
||||
testCfg.Postgres.Pass = randSeq(20)
|
||||
testCfg.Postgres.DBName = cfg.Postgres.DBName + "_" + randSeq(10)
|
||||
testCfg.Postgres.User = cfg.Postgres.User
|
||||
testCfg.Postgres.Pass = cfg.Postgres.Pass
|
||||
testCfg.Postgres.DBName = testDBName
|
||||
|
||||
err = dropTestDB()
|
||||
if err != nil {
|
||||
fmt.Printf("%#v\n", err)
|
||||
return err
|
||||
}
|
||||
|
||||
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
|
||||
if err != nil {
|
||||
|
@ -166,35 +199,24 @@ func setup() error {
|
|||
return err
|
||||
}
|
||||
|
||||
// Create the randomly generated database test user
|
||||
if err = createTestUser(dbConn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Create the randomly generated database
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Assign the randomly generated db test user to the generated test db
|
||||
_, err = dbConn.Exec(fmt.Sprintf(`ALTER DATABASE %s OWNER TO %s;`, testCfg.Postgres.DBName, testCfg.Postgres.User))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Close the old connection so we can reconnect with the restricted access generated user
|
||||
// Close the old connection so we can reconnect to the test database
|
||||
if err = dbConn.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Connect to the generated test db with the restricted privilege generated user
|
||||
// Connect to the generated test db
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the generated user password to a tmp file for pg_dump
|
||||
// Write the test config credentials to a tmp file for pg_dump
|
||||
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
||||
testCfg.Postgres.Host,
|
||||
testCfg.Postgres.Port,
|
||||
|
@ -233,18 +255,3 @@ func setup() error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createTestUser creates a temporary database user with restricted privileges
|
||||
func createTestUser(db *sql.DB) error {
|
||||
now := time.Now().Add(time.Hour * 24 * 2)
|
||||
valid := now.Format("2006-1-2")
|
||||
|
||||
query := fmt.Sprintf(`CREATE USER %s WITH PASSWORD '%s' VALID UNTIL '%s';`,
|
||||
testCfg.Postgres.User,
|
||||
testCfg.Postgres.Pass,
|
||||
valid,
|
||||
)
|
||||
|
||||
_, err := dbConn.Exec(query)
|
||||
return err
|
||||
}
|
||||
|
|
53
cmds/templates_test/singles/helper_funcs.tpl
Normal file
53
cmds/templates_test/singles/helper_funcs.tpl
Normal file
|
@ -0,0 +1,53 @@
|
|||
var dbNameRand *rand.Rand
|
||||
|
||||
func initDBNameRand(input string) {
|
||||
sum := md5.Sum([]byte(input))
|
||||
|
||||
var sumInt string
|
||||
for _, v := range sum {
|
||||
sumInt = sumInt + strconv.Itoa(int(v))
|
||||
}
|
||||
|
||||
// Cut integer to 18 digits to ensure no int64 overflow.
|
||||
sumInt = sumInt[:18]
|
||||
|
||||
sumTmp := sumInt
|
||||
for i, v := range sumInt {
|
||||
if v == '0' {
|
||||
sumTmp = sumInt[i+1:]
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
sumInt = sumTmp
|
||||
|
||||
randSeed, err := strconv.ParseInt(sumInt, 0, 64)
|
||||
if err != nil {
|
||||
fmt.Printf("Unable to parse sumInt: %s", err)
|
||||
os.Exit(-1)
|
||||
}
|
||||
|
||||
dbNameRand = rand.New(rand.NewSource(randSeed))
|
||||
}
|
||||
|
||||
var alphabetChars = "abcdefghijklmnopqrstuvwxyz"
|
||||
func randStr(length int) string {
|
||||
c := len(alphabetChars)
|
||||
|
||||
output := make([]rune, length)
|
||||
for i := 0; i < length; i++ {
|
||||
output[i] = rune(alphabetChars[dbNameRand.Intn(c)])
|
||||
}
|
||||
|
||||
return string(output)
|
||||
}
|
||||
|
||||
// getDBNameHash takes a database name in, and generates
|
||||
// a random string using the database name as the rand Seed.
|
||||
// getDBNameHash is used to generate unique test database names.
|
||||
func getDBNameHash(input string) string {
|
||||
initDBNameRand(input)
|
||||
return randStr(40)
|
||||
}
|
|
@ -29,6 +29,9 @@ type CmdData struct {
|
|||
SingleTemplates templater
|
||||
|
||||
TestTemplates templater
|
||||
// SingleTestTemplates are only created once, not per table
|
||||
SingleTestTemplates templater
|
||||
//TestMainTemplate is only created once, not per table
|
||||
TestMainTemplate *template.Template
|
||||
}
|
||||
|
||||
|
@ -37,6 +40,7 @@ type tplData struct {
|
|||
Table dbdrivers.Table
|
||||
DriverName string
|
||||
PkgName string
|
||||
Tables []string
|
||||
}
|
||||
|
||||
type importList []string
|
||||
|
|
|
@ -210,23 +210,37 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error)
|
|||
func (p *PostgresDriver) TranslateColumnType(c Column) Column {
|
||||
if c.IsNullable {
|
||||
switch c.Type {
|
||||
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
||||
c.Type = "null.Int"
|
||||
case "bigint", "bigserial":
|
||||
c.Type = "null.Int64"
|
||||
case "integer", "serial":
|
||||
c.Type = "null.Int32"
|
||||
case "smallint", "smallserial":
|
||||
c.Type = "null.Int16"
|
||||
case "decimal", "numeric", "double precision", "money":
|
||||
c.Type = "null.Float64"
|
||||
case "real":
|
||||
c.Type = "null.Float32"
|
||||
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
||||
c.Type = "null.String"
|
||||
case "boolean":
|
||||
c.Type = "null.Bool"
|
||||
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
||||
c.Type = "null.Time"
|
||||
case "double precision", "money", "numeric", "real":
|
||||
c.Type = "null.Float"
|
||||
default:
|
||||
c.Type = "null.String"
|
||||
}
|
||||
} else {
|
||||
switch c.Type {
|
||||
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
|
||||
case "bigint", "bigserial":
|
||||
c.Type = "int64"
|
||||
case "integer", "serial":
|
||||
c.Type = "int32"
|
||||
case "smallint", "smallserial":
|
||||
c.Type = "int16"
|
||||
case "decimal", "numeric", "double precision", "money":
|
||||
c.Type = "float64"
|
||||
case "real":
|
||||
c.Type = "float32"
|
||||
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
|
||||
c.Type = "string"
|
||||
case "bytea":
|
||||
|
@ -235,8 +249,6 @@ func (p *PostgresDriver) TranslateColumnType(c Column) Column {
|
|||
c.Type = "bool"
|
||||
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
|
||||
c.Type = "time.Time"
|
||||
case "double precision", "money", "numeric", "real":
|
||||
c.Type = "float64"
|
||||
default:
|
||||
c.Type = "string"
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue