diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index 637e1d8..4016912 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -20,10 +20,10 @@ type PostgresDriver struct { // returns a pointer to a PostgresDriver object. Note that it is required to // call PostgresDriver.Open() and PostgresDriver.Close() to open and close // the database connection once an object has been obtained. -func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver { +func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver { driver := PostgresDriver{ - connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d", - user, pass, dbname, host, port), + connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", + user, pass, dbname, host, port, sslmode), } return &driver diff --git a/config.go b/config.go index ac17593..4a05617 100644 --- a/config.go +++ b/config.go @@ -12,9 +12,10 @@ type Config struct { // PostgresConfig configures a postgres database type PostgresConfig struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` + User string `toml:"user"` + Pass string `toml:"pass"` + Host string `toml:"host"` + Port int `toml:"port"` + DBName string `toml:"dbname"` + SSLMode string `toml:"sslmode"` } diff --git a/main.go b/main.go index 64de52d..ae8ce23 100644 --- a/main.go +++ b/main.go @@ -63,6 +63,7 @@ func main() { rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") + viper.SetDefault("postgres.ssl_mode", "required") viper.BindPFlags(rootCmd.PersistentFlags()) if err := rootCmd.Execute(); err != nil { @@ -110,11 +111,12 @@ func preRun(cmd *cobra.Command, args []string) error { if viper.IsSet("postgres.dbname") { cmdConfig.Postgres = PostgresConfig{ - User: viper.GetString("postgres.user"), - Pass: viper.GetString("postgres.pass"), - Host: viper.GetString("postgres.host"), - Port: viper.GetInt("postgres.port"), - DBName: viper.GetString("postgres.dbname"), + User: viper.GetString("postgres.user"), + Pass: viper.GetString("postgres.pass"), + Host: viper.GetString("postgres.host"), + Port: viper.GetInt("postgres.port"), + DBName: viper.GetString("postgres.dbname"), + SSLMode: viper.GetString("postgres.sslmode"), } err = vala.BeginValidation().Validate( @@ -123,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error { vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"), vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")), vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"), + vala.StringNotEmpty(cmdConfig.Postgres.SSLMode, "postgres.sslmode"), ).Check() if err != nil { diff --git a/sqlboiler.go b/sqlboiler.go index fae50d8..f3eab79 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -177,6 +177,7 @@ func (s *State) initDriver(driverName string) error { s.Config.Postgres.DBName, s.Config.Postgres.Host, s.Config.Postgres.Port, + s.Config.Postgres.SSLMode, ) } diff --git a/templates_test/main_test/postgres_main.tpl b/templates_test/main_test/postgres_main.tpl index e05185d..40aaa54 100644 --- a/templates_test/main_test/postgres_main.tpl +++ b/templates_test/main_test/postgres_main.tpl @@ -1,9 +1,10 @@ type PostgresCfg struct { - User string `toml:"user"` - Pass string `toml:"pass"` - Host string `toml:"host"` - Port int `toml:"port"` - DBName string `toml:"dbname"` + User string `toml:"user"` + Pass string `toml:"pass"` + Host string `toml:"host"` + Port int `toml:"port"` + DBName string `toml:"dbname"` + SSLMode string `toml:"sslmode"` } type Config struct { @@ -79,7 +80,7 @@ func dropTestDB() error { } } - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port) + dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) if err != nil { return err } @@ -93,9 +94,9 @@ func dropTestDB() error { } // DBConnect connects to a database and returns the handle. -func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) { - connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d", - user, pass, dbname, host, port) +func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) { + connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", + user, pass, dbname, host, port, sslmode) return sql.Open("postgres", connStr) } @@ -112,12 +113,15 @@ func setup() error { return fmt.Errorf("Unable to load config file: %s", err) } + viper.SetDefault("postgres.sslmode", "required") + // Create a randomized test configuration object. testCfg.Postgres.Host = viper.GetString("postgres.host") testCfg.Postgres.Port = viper.GetInt("postgres.port") testCfg.Postgres.User = viper.GetString("postgres.user") testCfg.Postgres.Pass = viper.GetString("postgres.pass") testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname")) + testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode") err = vala.BeginValidation().Validate( vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), @@ -125,6 +129,7 @@ func setup() error { vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")), vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"), + vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"), ).Check() if err != nil { @@ -192,6 +197,7 @@ func setup() error { viper.GetString("postgres.dbname"), viper.GetString("postgres.host"), viper.GetInt("postgres.port"), + viper.GetString("postgres.sslmode"), ) if err != nil { return err @@ -209,7 +215,7 @@ func setup() error { } // Connect to the generated test db - dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port) + dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode) if err != nil { return err } @@ -221,6 +227,7 @@ func setup() error { testCfg.Postgres.DBName, testCfg.Postgres.User, testCfg.Postgres.Pass, + testCfg.Postgres.SSLMode, )) testPassFilePath := passDir + "/testpwfile"