Add postgres sslmode to configuration.

This commit is contained in:
Aaron L 2016-07-11 15:17:49 -07:00
parent 8cb8d1348a
commit 2b666e74de
5 changed files with 35 additions and 23 deletions

View file

@ -20,10 +20,10 @@ type PostgresDriver struct {
// returns a pointer to a PostgresDriver object. Note that it is required to // returns a pointer to a PostgresDriver object. Note that it is required to
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close // call PostgresDriver.Open() and PostgresDriver.Close() to open and close
// the database connection once an object has been obtained. // the database connection once an object has been obtained.
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver { func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
driver := PostgresDriver{ driver := PostgresDriver{
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d", connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
user, pass, dbname, host, port), user, pass, dbname, host, port, sslmode),
} }
return &driver return &driver

View file

@ -12,9 +12,10 @@ type Config struct {
// PostgresConfig configures a postgres database // PostgresConfig configures a postgres database
type PostgresConfig struct { type PostgresConfig struct {
User string `toml:"user"` User string `toml:"user"`
Pass string `toml:"pass"` Pass string `toml:"pass"`
Host string `toml:"host"` Host string `toml:"host"`
Port int `toml:"port"` Port int `toml:"port"`
DBName string `toml:"dbname"` DBName string `toml:"dbname"`
SSLMode string `toml:"sslmode"`
} }

13
main.go
View file

@ -63,6 +63,7 @@ func main() {
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
viper.SetDefault("postgres.ssl_mode", "required")
viper.BindPFlags(rootCmd.PersistentFlags()) viper.BindPFlags(rootCmd.PersistentFlags())
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
@ -110,11 +111,12 @@ func preRun(cmd *cobra.Command, args []string) error {
if viper.IsSet("postgres.dbname") { if viper.IsSet("postgres.dbname") {
cmdConfig.Postgres = PostgresConfig{ cmdConfig.Postgres = PostgresConfig{
User: viper.GetString("postgres.user"), User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"), Pass: viper.GetString("postgres.pass"),
Host: viper.GetString("postgres.host"), Host: viper.GetString("postgres.host"),
Port: viper.GetInt("postgres.port"), Port: viper.GetInt("postgres.port"),
DBName: viper.GetString("postgres.dbname"), DBName: viper.GetString("postgres.dbname"),
SSLMode: viper.GetString("postgres.sslmode"),
} }
err = vala.BeginValidation().Validate( err = vala.BeginValidation().Validate(
@ -123,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error {
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"), vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")), vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"), vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(cmdConfig.Postgres.SSLMode, "postgres.sslmode"),
).Check() ).Check()
if err != nil { if err != nil {

View file

@ -177,6 +177,7 @@ func (s *State) initDriver(driverName string) error {
s.Config.Postgres.DBName, s.Config.Postgres.DBName,
s.Config.Postgres.Host, s.Config.Postgres.Host,
s.Config.Postgres.Port, s.Config.Postgres.Port,
s.Config.Postgres.SSLMode,
) )
} }

View file

@ -1,9 +1,10 @@
type PostgresCfg struct { type PostgresCfg struct {
User string `toml:"user"` User string `toml:"user"`
Pass string `toml:"pass"` Pass string `toml:"pass"`
Host string `toml:"host"` Host string `toml:"host"`
Port int `toml:"port"` Port int `toml:"port"`
DBName string `toml:"dbname"` DBName string `toml:"dbname"`
SSLMode string `toml:"sslmode"`
} }
type Config struct { type Config struct {
@ -79,7 +80,7 @@ func dropTestDB() error {
} }
} }
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port) dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil { if err != nil {
return err return err
} }
@ -93,9 +94,9 @@ func dropTestDB() error {
} }
// DBConnect connects to a database and returns the handle. // DBConnect connects to a database and returns the handle.
func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) { func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d", connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
user, pass, dbname, host, port) user, pass, dbname, host, port, sslmode)
return sql.Open("postgres", connStr) return sql.Open("postgres", connStr)
} }
@ -112,12 +113,15 @@ func setup() error {
return fmt.Errorf("Unable to load config file: %s", err) return fmt.Errorf("Unable to load config file: %s", err)
} }
viper.SetDefault("postgres.sslmode", "required")
// Create a randomized test configuration object. // Create a randomized test configuration object.
testCfg.Postgres.Host = viper.GetString("postgres.host") testCfg.Postgres.Host = viper.GetString("postgres.host")
testCfg.Postgres.Port = viper.GetInt("postgres.port") testCfg.Postgres.Port = viper.GetInt("postgres.port")
testCfg.Postgres.User = viper.GetString("postgres.user") testCfg.Postgres.User = viper.GetString("postgres.user")
testCfg.Postgres.Pass = viper.GetString("postgres.pass") testCfg.Postgres.Pass = viper.GetString("postgres.pass")
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname")) testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
err = vala.BeginValidation().Validate( err = vala.BeginValidation().Validate(
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"), vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
@ -125,6 +129,7 @@ func setup() error {
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"), vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")), vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"), vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
).Check() ).Check()
if err != nil { if err != nil {
@ -192,6 +197,7 @@ func setup() error {
viper.GetString("postgres.dbname"), viper.GetString("postgres.dbname"),
viper.GetString("postgres.host"), viper.GetString("postgres.host"),
viper.GetInt("postgres.port"), viper.GetInt("postgres.port"),
viper.GetString("postgres.sslmode"),
) )
if err != nil { if err != nil {
return err return err
@ -209,7 +215,7 @@ func setup() error {
} }
// Connect to the generated test db // Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port) dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil { if err != nil {
return err return err
} }
@ -221,6 +227,7 @@ func setup() error {
testCfg.Postgres.DBName, testCfg.Postgres.DBName,
testCfg.Postgres.User, testCfg.Postgres.User,
testCfg.Postgres.Pass, testCfg.Postgres.Pass,
testCfg.Postgres.SSLMode,
)) ))
testPassFilePath := passDir + "/testpwfile" testPassFilePath := passDir + "/testpwfile"