Added debug mode, test singles

* Where statement OR/AND support
* Added null-extended library to use different data types
* Added disable triggers function for test main
This commit is contained in:
Patrick O'brien 2016-06-03 07:07:51 +10:00
parent c2541ea56e
commit 3152eed170
13 changed files with 386 additions and 105 deletions

View file

@ -1,6 +1,9 @@
package boil
import "database/sql"
import (
"database/sql"
"os"
)
type Executor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
@ -21,6 +24,15 @@ type Creator interface {
var currentDB Executor
// DebugMode is a flag controlling whether generated sql statements and
// debug information is outputted to the DebugWriter handle
//
// NOTE: This should be disabled in production to avoid leaking sensitive data
var DebugMode = false
// DebugWriter is where the debug output will be sent if DebugMode is true
var DebugWriter = os.Stdout
func Begin() (Transactor, error) {
creator, ok := currentDB.(Creator)
if !ok {

View file

@ -10,6 +10,15 @@ func Apply(q *boil.Query, mods ...QueryMod) {
}
}
func Or(whereMods ...QueryMod) QueryMod {
return func(q *boil.Query) {
if len(whereMods) < 2 {
// error, needs to be at least 2 for an or
}
// add the where mods to query with or seperators
}
}
func Limit(limit int) QueryMod {
return func(q *boil.Query) {
boil.SetLimit(q, limit)

View file

@ -8,8 +8,9 @@ import (
)
type where struct {
clause string
args []interface{}
clause string
orSeperator bool
args []interface{}
}
type join struct {
@ -65,36 +66,65 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteString(" FROM ")
fmt.Fprintf(buf, `"%s"`, q.table)
buf.WriteByte(';')
return buf, []interface{}{}
}
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ")
fmt.Fprintf(buf, `"%s"`, q.table)
if len(q.where) > 0 {
for i := 0; i < len(q.where); i++ {
buf.WriteString(fmt.Sprintf(` WHERE %s`, q.where[i].clause))
if i != len(q.where)-1 {
if q.where[i].orSeperator {
buf.WriteString(` OR `)
} else {
buf.WriteString(` AND `)
}
}
}
}
buf.WriteByte(';')
return buf, nil
}
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteByte(';')
return buf, nil
}
// ExecQuery executes a query that does not need a row returned
func ExecQuery(q *Query) (sql.Result, error) {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
}
return q.executor.Exec(qs, args...)
}
// ExecQueryOne executes the query for the One finisher and returns a row
func ExecQueryOne(q *Query) *sql.Row {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
}
return q.executor.QueryRow(qs, args)
}
// ExecQueryAll executes the query for the All finisher and returns multiple rows
func ExecQueryAll(q *Query) (*sql.Rows, error) {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
}
return q.executor.Query(qs, args)
}

View file

@ -3,22 +3,33 @@ package boil
import (
"database/sql"
"fmt"
"math"
"math/rand"
"reflect"
"sort"
"time"
"github.com/guregu/null"
"github.com/pobri19/sqlboiler/strmangle"
"gopkg.in/BlackBaronsTux/null-extended.v1"
)
var (
typeNullInt = reflect.TypeOf(null.Int{})
typeNullFloat = reflect.TypeOf(null.Float{})
typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeTime = reflect.TypeOf(time.Time{})
typeNullFloat32 = reflect.TypeOf(null.Float32{})
typeNullFloat64 = reflect.TypeOf(null.Float64{})
typeNullInt = reflect.TypeOf(null.Int{})
typeNullInt8 = reflect.TypeOf(null.Int8{})
typeNullInt16 = reflect.TypeOf(null.Int16{})
typeNullInt32 = reflect.TypeOf(null.Int32{})
typeNullInt64 = reflect.TypeOf(null.Int64{})
typeNullUint = reflect.TypeOf(null.Uint{})
typeNullUint8 = reflect.TypeOf(null.Uint8{})
typeNullUint16 = reflect.TypeOf(null.Uint16{})
typeNullUint32 = reflect.TypeOf(null.Uint32{})
typeNullUint64 = reflect.TypeOf(null.Uint64{})
typeNullString = reflect.TypeOf(null.String{})
typeNullBool = reflect.TypeOf(null.Bool{})
typeNullTime = reflect.TypeOf(null.Time{})
typeTime = reflect.TypeOf(time.Time{})
)
// Bind executes the query and inserts the
@ -163,27 +174,65 @@ func randomizeField(field reflect.Value) error {
if kind == reflect.Struct {
switch typ {
case typeNullInt:
newVal = null.NewInt(rand.Int63(), rand.Intn(2) == 1)
case typeNullFloat:
newVal = null.NewFloat(rand.Float64(), rand.Intn(2) == 1)
case typeNullBool:
newVal = null.NewBool(rand.Intn(2) == 1, rand.Intn(2) == 1)
case typeNullString:
newVal = null.NewString(randStr(5+rand.Intn(25)), rand.Intn(2) == 1)
case typeNullTime:
newVal = null.NewTime(time.Unix(rand.Int63(), 0), rand.Intn(2) == 1)
newVal = null.NewTime(time.Now().Add(time.Duration(rand.Intn((int(time.Hour * 24 * 10))))), rand.Intn(2) == 1)
case typeTime:
newVal = time.Unix(rand.Int63(), 0)
newVal = time.Now().Add(time.Duration(rand.Intn((int(time.Hour * 24 * 10)))))
case typeNullFloat32:
newVal = null.NewFloat32(rand.Float32(), rand.Intn(2) == 1)
case typeNullFloat64:
newVal = null.NewFloat64(rand.Float64(), rand.Intn(2) == 1)
case typeNullInt:
newVal = null.NewInt(rand.Int(), rand.Intn(2) == 1)
case typeNullInt8:
newVal = null.NewInt8(int8(rand.Intn(int(math.MaxInt8))), rand.Intn(2) == 1)
case typeNullInt16:
newVal = null.NewInt16(int16(rand.Intn(int(math.MaxInt16))), rand.Intn(2) == 1)
case typeNullInt32:
newVal = null.NewInt32(rand.Int31(), rand.Intn(2) == 1)
case typeNullInt64:
newVal = null.NewInt64(rand.Int63(), rand.Intn(2) == 1)
case typeNullUint:
newVal = null.NewUint(uint(rand.Int()), rand.Intn(2) == 1)
case typeNullUint8:
newVal = null.NewUint8(uint8(rand.Intn(int(math.MaxInt8))), rand.Intn(2) == 1)
case typeNullUint16:
newVal = null.NewUint16(uint16(rand.Intn(int(math.MaxInt16))), rand.Intn(2) == 1)
case typeNullUint32:
newVal = null.NewUint32(uint32(rand.Int31()), rand.Intn(2) == 1)
case typeNullUint64:
newVal = null.NewUint64(uint64(rand.Int63()), rand.Intn(2) == 1)
}
} else {
switch kind {
case reflect.Int:
newVal = rand.Int()
case reflect.Int64:
newVal = rand.Int63()
case reflect.Float32:
newVal = rand.Float32()
case reflect.Float64:
newVal = rand.Float64()
case reflect.Int:
newVal = rand.Int()
case reflect.Int8:
newVal = int8(rand.Intn(int(math.MaxInt8)))
case reflect.Int16:
newVal = int16(rand.Intn(int(math.MaxInt16)))
case reflect.Int32:
newVal = rand.Int31()
case reflect.Int64:
newVal = rand.Int63()
case reflect.Uint:
newVal = uint(rand.Int())
case reflect.Uint8:
newVal = uint8(rand.Intn(int(math.MaxInt8)))
case reflect.Uint16:
newVal = uint16(rand.Intn(int(math.MaxInt16)))
case reflect.Uint32:
newVal = uint32(rand.Int31())
case reflect.Uint64:
newVal = uint64(rand.Int63())
case reflect.Bool:
var b bool
if rand.Intn(2) == 1 {

View file

@ -12,20 +12,50 @@ import (
// sqlBoilerTypeImports imports are only included in the template output if the database
// requires one of the following special types. Check TranslateColumnType to see the type assignments.
var sqlBoilerTypeImports = map[string]imports{
"null.Float32": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Float64": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Int": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Int8": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Int16": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Int32": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Int64": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Uint": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Uint8": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Uint16": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Uint32": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Uint64": imports{
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.String": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Bool": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
},
"null.Float": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"null.Time": imports{
thirdparty: importList{`"gopkg.in/guregu/null.v3"`},
thirdparty: importList{`"gopkg.in/BlackBaronsTux/null-extended.v1"`},
},
"time.Time": imports{
standard: importList{`"time"`},
@ -65,6 +95,19 @@ var sqlBoilerTestImports = imports{
},
}
var sqlBoilerSinglesTestImports = map[string]imports{
"helper_funcs": imports{
standard: importList{
`"crypto/md5"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"math/rand"`,
},
thirdparty: importList{},
},
}
var sqlBoilerTestMainImports = map[string]imports{
"postgres": imports{
standard: importList{

View file

@ -111,6 +111,42 @@ func generateSinglesOutput(cmdData *CmdData) error {
return nil
}
func generateSinglesTestOutput(cmdData *CmdData) error {
if cmdData.SingleTestTemplates == nil {
return errors.New("No single test templates located for generation")
}
tplData := &tplData{
PkgName: cmdData.PkgName,
DriverName: cmdData.DriverName,
}
for _, template := range cmdData.SingleTestTemplates {
var imps imports
resp, err := generateTemplate(template, tplData)
if err != nil {
return fmt.Errorf("Error generating test template %s: %s", template.Name(), err)
}
fName := template.Name()
ext := filepath.Ext(fName)
fName = fName[0 : len(fName)-len(ext)]
imps.standard = sqlBoilerSinglesTestImports[fName].standard
imps.thirdparty = sqlBoilerSinglesTestImports[fName].thirdparty
fName = fName + "_test.go"
err = outHandler(cmdData.OutFolder, fName, cmdData.PkgName, imps, [][]byte{resp})
if err != nil {
return err
}
}
return nil
}
func generateTestMainOutput(cmdData *CmdData) error {
if cmdData.TestMainTemplate == nil {
return errors.New("No TestMain template located for generation")
@ -122,9 +158,15 @@ func generateTestMainOutput(cmdData *CmdData) error {
imps.standard = sqlBoilerTestMainImports[cmdData.DriverName].standard
imps.thirdparty = sqlBoilerTestMainImports[cmdData.DriverName].thirdparty
var tables []string
for _, v := range cmdData.Tables {
tables = append(tables, v.Name)
}
tplData := &tplData{
PkgName: cmdData.PkgName,
DriverName: cmdData.DriverName,
Tables: tables,
}
resp, err := generateTemplate(cmdData.TestMainTemplate, tplData)

View file

@ -14,10 +14,12 @@ import (
)
const (
templatesDirectory = "/cmds/templates"
templatesSinglesDirectory = "/cmds/templates/singles"
templatesTestDirectory = "/cmds/templates_test"
templatesTestMainDirectory = "/cmds/templates_test/main_test"
templatesDirectory = "/cmds/templates"
templatesSinglesDirectory = "/cmds/templates/singles"
templatesTestDirectory = "/cmds/templates_test"
templatesSinglesTestDirectory = "/cmds/templates_test/singles"
templatesTestMainDirectory = "/cmds/templates_test/main_test"
)
// LoadTemplates loads all template folders into the cmdData object.
@ -38,6 +40,11 @@ func initTemplates(cmdData *CmdData) error {
return err
}
cmdData.SingleTestTemplates, err = loadTemplates(templatesSinglesTestDirectory)
if err != nil {
return err
}
filename := cmdData.DriverName + "_main.tpl"
cmdData.TestMainTemplate, err = loadTemplate(templatesTestMainDirectory, filename)
if err != nil {
@ -118,14 +125,18 @@ func (c *CmdData) SQLBoilerRun(cmd *cobra.Command, args []string) error {
// run executes the sqlboiler templates and outputs them to files.
func (c *CmdData) run(includeTests bool) error {
if err := generateSinglesOutput(c); err != nil {
return fmt.Errorf("Unable to generate single templates output: %s", err)
}
if includeTests {
if err := generateTestMainOutput(c); err != nil {
return fmt.Errorf("Unable to generate TestMain output: %s", err)
}
}
if err := generateSinglesOutput(c); err != nil {
return fmt.Errorf("Unable to generate single templates output: %s", err)
if err := generateSinglesTestOutput(c); err != nil {
return fmt.Errorf("Unable to generate single test templates output: %s", err)
}
}
for _, table := range c.Tables {

View file

@ -63,12 +63,16 @@ func (o *{{$tableNameSingular}}) InsertX(exec boil.Executor, whitelist ... strin
{{else}}
if len(returnColumns) != 0 {
ins = ins + fmt.Sprintf(` RETURNING %s`, strings.Join(returnColumns, ","))
err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...))
err = exec.QueryRow(ins, boil.GetStructValues(o, wl...)...).Scan(boil.GetStructPointers(o, returnColumns...)...)
} else {
_, err = exec.Exec(ins, {{insertParamVariables "o." .Table.Columns}})
}
{{end}}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, ins)
}
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to insert into {{.Table.Name}}: %s", err)
}

View file

@ -35,13 +35,18 @@ func (o *{{$tableNameSingular}}) UpdateAtX(exec boil.Executor, {{primaryKeyFuncS
}
var err error
var query string
if len(whitelist) != 0 {
query := fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}}))
query = fmt.Sprintf(`UPDATE {{.Table.Name}} SET %s WHERE %s`, boil.SetParamNames(whitelist), boil.WherePrimaryKey(len(whitelist)+1, {{commaList .Table.PKey.Columns}}))
_, err = exec.Exec(query, boil.GetStructValues(o, whitelist...), {{paramsPrimaryKey "o." .Table.PKey.Columns true}})
} else {
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}}, could not build a whitelist for row: %s", err)
}
if boil.DebugMode {
fmt.Fprintln(boil.DebugWriter, query)
}
if err != nil {
return fmt.Errorf("{{.PkgName}}: unable to update {{.Table.Name}} row: %s", err)
}

View file

@ -16,46 +16,81 @@ var testCfg *Config
var dbConn *sql.DB
func TestMain(m *testing.M) {
rand.Seed(time.Now().UnixNano())
// Set the DebugMode to true so we can see generated sql statements
boil.DebugMode = true
err := setup()
rand.Seed(time.Now().UnixNano())
var err error
err = setup()
if err != nil {
fmt.Println(err)
fmt.Printf("Unable to execute setup: %s", err)
os.Exit(-1)
}
err = disableTriggers()
if err != nil {
fmt.Printf("Unable to disable triggers: %s", err)
}
boil.SetDB(dbConn)
code := m.Run()
err = teardown()
if err != nil {
fmt.Println(err)
fmt.Printf("Unable to execute teardown: %s", err)
os.Exit(-1)
}
os.Exit(code)
}
// teardown switches its connection to the template1 database temporarily
// so that it can drop the test database and the test user.
// The template1 database should be present on all default postgres installations.
// disableTriggers is used to disable foreign key constraints for every table.
// If this is not used we cannot test inserts due to foreign key constraint errors.
func disableTriggers() error {
var stmts []string
{{range .Tables}}
stmts = append(stmts, `ALTER TABLE {{.}} DISABLE TRIGGER ALL;`)
{{- end}}
if len(stmts) == 0 {
return nil
}
var err error
for _, s := range stmts {
_, err = dbConn.Exec(s)
if err != nil {
return err
}
}
return nil
}
// teardown executes cleanup tasks when the tests finish running
func teardown() error {
err := dbConn.Close()
err := dropTestDB()
return err
}
// dropTestDB switches its connection to the template1 database temporarily
// so that it can drop the test database without causing "in use" conflicts.
// The template1 database should be present on all default postgres installations.
func dropTestDB() error {
var err error
if dbConn != nil {
if err = dbConn.Close(); err != nil {
return err
}
}
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port)
if err != nil {
return err
}
dbConn, err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, "template1", cfg.Postgres.Host, cfg.Postgres.Port)
if err != nil {
return err
}
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE %s;`, testCfg.Postgres.DBName))
if err != nil {
return err
}
_, err = dbConn.Exec(fmt.Sprintf(`DROP USER %s;`, testCfg.Postgres.User))
_, err = dbConn.Exec(fmt.Sprintf(`DROP DATABASE IF EXISTS %s;`, testCfg.Postgres.DBName))
if err != nil {
return err
}
@ -71,16 +106,6 @@ func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
return sql.Open("postgres", connStr)
}
func randSeq(n int) string {
var letters = []rune("abcdefghijklmnopqrstuvwxyz")
randStr := make([]rune, n)
for i := range randStr {
randStr[i] = letters[rand.Intn(len(letters))]
}
return string(randStr)
}
func LoadConfigFile(filename string) error {
_, err := toml.DecodeFile(filename, &cfg)
@ -105,13 +130,21 @@ func setup() error {
return fmt.Errorf("Unable to load config file: %s", err)
}
testDBName := getDBNameHash(cfg.Postgres.DBName)
// Create a randomized test configuration object.
testCfg = &Config{}
testCfg.Postgres.Host = cfg.Postgres.Host
testCfg.Postgres.Port = cfg.Postgres.Port
testCfg.Postgres.User = randSeq(20)
testCfg.Postgres.Pass = randSeq(20)
testCfg.Postgres.DBName = cfg.Postgres.DBName + "_" + randSeq(10)
testCfg.Postgres.User = cfg.Postgres.User
testCfg.Postgres.Pass = cfg.Postgres.Pass
testCfg.Postgres.DBName = testDBName
err = dropTestDB()
if err != nil {
fmt.Printf("%#v\n", err)
return err
}
fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema")
if err != nil {
@ -166,35 +199,24 @@ func setup() error {
return err
}
// Create the randomly generated database test user
if err = createTestUser(dbConn); err != nil {
return err
}
// Create the randomly generated database
_, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, testCfg.Postgres.DBName))
if err != nil {
return err
}
// Assign the randomly generated db test user to the generated test db
_, err = dbConn.Exec(fmt.Sprintf(`ALTER DATABASE %s OWNER TO %s;`, testCfg.Postgres.DBName, testCfg.Postgres.User))
if err != nil {
return err
}
// Close the old connection so we can reconnect with the restricted access generated user
// Close the old connection so we can reconnect to the test database
if err = dbConn.Close(); err != nil {
return err
}
// Connect to the generated test db with the restricted privilege generated user
// Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
if err != nil {
return err
}
// Write the generated user password to a tmp file for pg_dump
// Write the test config credentials to a tmp file for pg_dump
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
testCfg.Postgres.Host,
testCfg.Postgres.Port,
@ -233,18 +255,3 @@ func setup() error {
return nil
}
// createTestUser creates a temporary database user with restricted privileges
func createTestUser(db *sql.DB) error {
now := time.Now().Add(time.Hour * 24 * 2)
valid := now.Format("2006-1-2")
query := fmt.Sprintf(`CREATE USER %s WITH PASSWORD '%s' VALID UNTIL '%s';`,
testCfg.Postgres.User,
testCfg.Postgres.Pass,
valid,
)
_, err := dbConn.Exec(query)
return err
}

View file

@ -0,0 +1,53 @@
var dbNameRand *rand.Rand
func initDBNameRand(input string) {
sum := md5.Sum([]byte(input))
var sumInt string
for _, v := range sum {
sumInt = sumInt + strconv.Itoa(int(v))
}
// Cut integer to 18 digits to ensure no int64 overflow.
sumInt = sumInt[:18]
sumTmp := sumInt
for i, v := range sumInt {
if v == '0' {
sumTmp = sumInt[i+1:]
continue
}
break
}
sumInt = sumTmp
randSeed, err := strconv.ParseInt(sumInt, 0, 64)
if err != nil {
fmt.Printf("Unable to parse sumInt: %s", err)
os.Exit(-1)
}
dbNameRand = rand.New(rand.NewSource(randSeed))
}
var alphabetChars = "abcdefghijklmnopqrstuvwxyz"
func randStr(length int) string {
c := len(alphabetChars)
output := make([]rune, length)
for i := 0; i < length; i++ {
output[i] = rune(alphabetChars[dbNameRand.Intn(c)])
}
return string(output)
}
// getDBNameHash takes a database name in, and generates
// a random string using the database name as the rand Seed.
// getDBNameHash is used to generate unique test database names.
func getDBNameHash(input string) string {
initDBNameRand(input)
return randStr(40)
}

View file

@ -28,7 +28,10 @@ type CmdData struct {
// SingleTemplates are only created once, not per table
SingleTemplates templater
TestTemplates templater
TestTemplates templater
// SingleTestTemplates are only created once, not per table
SingleTestTemplates templater
//TestMainTemplate is only created once, not per table
TestMainTemplate *template.Template
}
@ -37,6 +40,7 @@ type tplData struct {
Table dbdrivers.Table
DriverName string
PkgName string
Tables []string
}
type importList []string

View file

@ -210,23 +210,37 @@ func (p *PostgresDriver) ForeignKeyInfo(tableName string) ([]ForeignKey, error)
func (p *PostgresDriver) TranslateColumnType(c Column) Column {
if c.IsNullable {
switch c.Type {
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
c.Type = "null.Int"
case "bigint", "bigserial":
c.Type = "null.Int64"
case "integer", "serial":
c.Type = "null.Int32"
case "smallint", "smallserial":
c.Type = "null.Int16"
case "decimal", "numeric", "double precision", "money":
c.Type = "null.Float64"
case "real":
c.Type = "null.Float32"
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
c.Type = "null.String"
case "boolean":
c.Type = "null.Bool"
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "null.Time"
case "double precision", "money", "numeric", "real":
c.Type = "null.Float"
default:
c.Type = "null.String"
}
} else {
switch c.Type {
case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":
case "bigint", "bigserial":
c.Type = "int64"
case "integer", "serial":
c.Type = "int32"
case "smallint", "smallserial":
c.Type = "int16"
case "decimal", "numeric", "double precision", "money":
c.Type = "float64"
case "real":
c.Type = "float32"
case "bit", "bit varying", "character", "character varying", "cidr", "inet", "json", "macaddr", "text", "uuid", "xml":
c.Type = "string"
case "bytea":
@ -235,8 +249,6 @@ func (p *PostgresDriver) TranslateColumnType(c Column) Column {
c.Type = "bool"
case "date", "interval", "time", "timestamp without time zone", "timestamp with time zone":
c.Type = "time.Time"
case "double precision", "money", "numeric", "real":
c.Type = "float64"
default:
c.Type = "string"
}