diff --git a/cmds/boil.go b/cmds/boil.go index 6deae25..73d87be 100644 --- a/cmds/boil.go +++ b/cmds/boil.go @@ -18,9 +18,9 @@ func boilRun(cmd *cobra.Command, args []string) { // Prepend "struct" command to templateNames slice so it sits at top of sort commandNames = append([]string{"struct"}, commandNames...) - // Create a testCommandNames with "main" prepended to the front for the test templates + // Create a testCommandNames with "driverName_main" prepended to the front for the test templates // the main template initializes all of the testing assets - testCommandNames := append([]string{"main"}, commandNames...) + testCommandNames := append([]string{cmdData.DriverName + "_main"}, commandNames...) for _, table := range cmdData.Tables { data := &tplData{ @@ -54,6 +54,8 @@ func boilRun(cmd *cobra.Command, args []string) { testImps.standard = sqlBoilerDefaultTestImports.standard testImps.thirdparty = sqlBoilerDefaultTestImports.thirdparty + testImps = combineImports(testImps, sqlBoilerConditionalDriverTestImports[cmdData.DriverName]) + // Loop through and generate every command test template (excluding skipTemplates) for _, command := range testCommandNames { testImps = combineImports(testImps, sqlBoilerCustomTestImports[command]) diff --git a/cmds/commands.go b/cmds/commands.go index d4a9e33..21845aa 100644 --- a/cmds/commands.go +++ b/cmds/commands.go @@ -30,6 +30,15 @@ var sqlBoilerDefaultImports = imports{ var sqlBoilerDefaultTestImports = imports{ standard: importList{ `"testing"`, + `"os"`, + `"os/exec"`, + `"fmt"`, + `"io/ioutil"`, + `"bytes"`, + `"errors"`, + }, + thirdparty: importList{ + `"github.com/BurntSushi/toml"`, }, } @@ -57,6 +66,15 @@ var sqlBoilerConditionalTypeImports = map[string]imports{ }, } +// sqlBoilerConditionalDriverTestImports defines the test template imports +// for the particular database interfaces +var sqlBoilerConditionalDriverTestImports = map[string]imports{ + "postgres": imports{ + standard: importList{`"database/sql"`}, + thirdparty: importList{`_ "github.com/lib/pq"`}, + }, +} + var sqlBoilerCustomImports map[string]imports var sqlBoilerCustomTestImports map[string]imports diff --git a/cmds/config.go b/cmds/config.go index 72cefa1..c4cca6d 100644 --- a/cmds/config.go +++ b/cmds/config.go @@ -16,13 +16,7 @@ type PostgresCfg struct { } type Config struct { - Postgres PostgresCfg `toml:"postgres"` - // The TestPostgres object holds the configuration pointing to a test database. - // This test database is used to test all of the commands that have an accompanying - // command.testtpl file. These template files generate the go test functions. - // - // Note: These test templates will only be generated for the boil command, - // if an OutFolder is provided, and if all test config variables are present. + Postgres PostgresCfg `toml:"postgres"` TestPostgres *PostgresCfg `toml:"postgres_test"` } diff --git a/cmds/shared.go b/cmds/shared.go index 98f6e02..7c06986 100644 --- a/cmds/shared.go +++ b/cmds/shared.go @@ -17,10 +17,11 @@ type CobraRunFunc func(cmd *cobra.Command, args []string) // the database driver chosen by the driver flag at runtime, and a pointer to the // output file, if one is specified with a flag. type CmdData struct { - Tables []dbdrivers.Table - PkgName string - OutFolder string - Interface dbdrivers.Interface + Tables []dbdrivers.Table + PkgName string + OutFolder string + Interface dbdrivers.Interface + DriverName string } // tplData is used to pass data to the template diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index 517a9a0..9b627f7 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -124,6 +124,8 @@ func initInterface() { if cmdData.Interface == nil { errorQuit(errors.New("An invalid driver name was provided")) } + + cmdData.DriverName = driverName } // initTables will create a string slice out of the passed in table flag value diff --git a/cmds/templates_test/main.tpl b/cmds/templates_test/main.tpl deleted file mode 100644 index 7d256c5..0000000 --- a/cmds/templates_test/main.tpl +++ /dev/null @@ -1,49 +0,0 @@ -type PostgresCfg struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` -} - -type Config struct { - Postgres PostgresCfg `toml:"postgres"` - TestPostgres *PostgresCfg `toml:"postgres_test"` -} - -var cfg *Config - -func LoadConfigFile(filename string) { - _, err := toml.DecodeFile(filename, &cfg) - - if os.IsNotExist(err) { - fmt.Fatalf("Failed to find the toml configuration file %s: %s", filename, err) - } - - if err != nil { - fmt.Fatalf("Failed to decode toml configuration file:", err) - } - - if cfg.TestPostgres != nil { - if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" || - cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 || - cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName { - cfg.TestPostgres = nil - } - } - - if cfg.TestPostgres == nil { - fmt.Fatalf("Failed to load config.toml postgres_test config") - } -} - -func TestMain(m *testing.M) { - setup() - code := m.Run() - // shutdown - os.Exit(code) -} - -func setup() { - LoadConfigFile("../config.toml") -} diff --git a/cmds/templates_test/postgres_main.tpl b/cmds/templates_test/postgres_main.tpl new file mode 100644 index 0000000..028cb53 --- /dev/null +++ b/cmds/templates_test/postgres_main.tpl @@ -0,0 +1,115 @@ +type PostgresCfg struct { + User string `toml:"user"` + Pass string `toml:"pass"` + Host string `toml:"host"` + Port int `toml:"port"` + DBName string `toml:"dbname"` +} + +type Config struct { + Postgres PostgresCfg `toml:"postgres"` + TestPostgres *PostgresCfg `toml:"postgres_test"` +} + +var cfg *Config + +var dbConn *sql.DB + +func DBConnect(user, pass, dbname, host string, port int) error { + connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d", + user, pass, dbname, host, port) + + var err error + dbConn, err = sql.Open("postgres", connStr) + if err != nil { + return err + } + + return nil +} + +func LoadConfigFile(filename string) error { + _, err := toml.DecodeFile(filename, &cfg) + + if os.IsNotExist(err) { + return fmt.Errorf("Failed to find the toml configuration file %s: %s", filename, err) + } + + if err != nil { + return fmt.Errorf("Failed to decode toml configuration file:", err) + } + + if cfg.TestPostgres != nil { + if cfg.TestPostgres.User == "" || cfg.TestPostgres.Pass == "" || + cfg.TestPostgres.Host == "" || cfg.TestPostgres.Port == 0 || + cfg.TestPostgres.DBName == "" || cfg.Postgres.DBName == cfg.TestPostgres.DBName { + cfg.TestPostgres = nil + } + } + + if cfg.TestPostgres == nil { + return errors.New("Failed to load config.toml postgres_test config") + } + + return nil +} + +func TestMain(m *testing.M) { + err := setup() + if err != nil { + os.Exit(-1) + } + code := m.Run() + // shutdown + os.Exit(code) +} + +func setup() error { + err := LoadConfigFile("../config.toml") + if err != nil { + return fmt.Errorf("Unable to load config file: %s", err) + } + + fhSchema, err := ioutil.TempFile(os.TempDir(), "sqlboilerschema") + if err != nil { + return fmt.Errorf("Unable to create sqlboiler schema tmp file: %s", err) + } + defer os.Remove(fhSchema.Name()) + + passDir, err := ioutil.TempDir(os.TempDir(), "sqlboiler") + if err != nil { + return fmt.Errorf("Unable to create sqlboiler tmp dir for postgres pw file: %s", err) + } + defer os.RemoveAll(passDir) + + // Write the postgres user password to a tmp file for pg_dump + pwBytes := []byte(fmt.Sprintf("*:*:*:*:%s", cfg.Postgres.Pass)) + passFilePath := passDir + "/pwfile" + + err = ioutil.WriteFile(passFilePath, pwBytes, 0600) + if err != nil { + return fmt.Errorf("Unable to create pwfile in passDir: %s", err) + } + + params := []string{ + fmt.Sprintf(`--host=%s`, cfg.Postgres.Host), + fmt.Sprintf(`--port=%d`, cfg.Postgres.Port), + fmt.Sprintf(`--username=%s`, cfg.Postgres.User), + "--schema-only", + cfg.Postgres.DBName, + } + + errBuf := bytes.Buffer{} + cmd := exec.Command("pg_dump", params...) + cmd.Stderr = &errBuf + cmd.Stdout = fhSchema + cmd.Env = append(os.Environ(), fmt.Sprintf(`PGPASSFILE=%s`, passFilePath)) + + if err := cmd.Run(); err != nil { + fmt.Printf("pg_dump exec failed: %s\n\n%s\n", err, errBuf.String()) + } + + err = DBConnect(cfg.Postgres.User, cfg.Postgres.Pass, cfg.Postgres.DBName, cfg.Postgres.Host, cfg.Postgres.Port) + _, err = dbConn.Exec(fmt.Sprintf(`CREATE DATABASE %s WITH ENCODING 'UTF8'`, cfg.TestPostgres.DBName)) + return nil +}