Add postgres sslmode to configuration.

This commit is contained in:
Aaron L 2016-07-11 15:17:49 -07:00
parent 8cb8d1348a
commit 2b666e74de
5 changed files with 35 additions and 23 deletions

View file

@ -20,10 +20,10 @@ type PostgresDriver struct {
// returns a pointer to a PostgresDriver object. Note that it is required to
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close
// the database connection once an object has been obtained.
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver {
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
driver := PostgresDriver{
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
user, pass, dbname, host, port),
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
user, pass, dbname, host, port, sslmode),
}
return &driver

View file

@ -12,9 +12,10 @@ type Config struct {
// PostgresConfig configures a postgres database
type PostgresConfig struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
SSLMode string `toml:"sslmode"`
}

13
main.go
View file

@ -63,6 +63,7 @@ func main() {
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
viper.SetDefault("postgres.ssl_mode", "required")
viper.BindPFlags(rootCmd.PersistentFlags())
if err := rootCmd.Execute(); err != nil {
@ -110,11 +111,12 @@ func preRun(cmd *cobra.Command, args []string) error {
if viper.IsSet("postgres.dbname") {
cmdConfig.Postgres = PostgresConfig{
User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"),
Host: viper.GetString("postgres.host"),
Port: viper.GetInt("postgres.port"),
DBName: viper.GetString("postgres.dbname"),
User: viper.GetString("postgres.user"),
Pass: viper.GetString("postgres.pass"),
Host: viper.GetString("postgres.host"),
Port: viper.GetInt("postgres.port"),
DBName: viper.GetString("postgres.dbname"),
SSLMode: viper.GetString("postgres.sslmode"),
}
err = vala.BeginValidation().Validate(
@ -123,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error {
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(cmdConfig.Postgres.SSLMode, "postgres.sslmode"),
).Check()
if err != nil {

View file

@ -177,6 +177,7 @@ func (s *State) initDriver(driverName string) error {
s.Config.Postgres.DBName,
s.Config.Postgres.Host,
s.Config.Postgres.Port,
s.Config.Postgres.SSLMode,
)
}

View file

@ -1,9 +1,10 @@
type PostgresCfg struct {
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
User string `toml:"user"`
Pass string `toml:"pass"`
Host string `toml:"host"`
Port int `toml:"port"`
DBName string `toml:"dbname"`
SSLMode string `toml:"sslmode"`
}
type Config struct {
@ -79,7 +80,7 @@ func dropTestDB() error {
}
}
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port)
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
@ -93,9 +94,9 @@ func dropTestDB() error {
}
// DBConnect connects to a database and returns the handle.
func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
user, pass, dbname, host, port)
func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
user, pass, dbname, host, port, sslmode)
return sql.Open("postgres", connStr)
}
@ -112,12 +113,15 @@ func setup() error {
return fmt.Errorf("Unable to load config file: %s", err)
}
viper.SetDefault("postgres.sslmode", "required")
// Create a randomized test configuration object.
testCfg.Postgres.Host = viper.GetString("postgres.host")
testCfg.Postgres.Port = viper.GetInt("postgres.port")
testCfg.Postgres.User = viper.GetString("postgres.user")
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
err = vala.BeginValidation().Validate(
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
@ -125,6 +129,7 @@ func setup() error {
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
).Check()
if err != nil {
@ -192,6 +197,7 @@ func setup() error {
viper.GetString("postgres.dbname"),
viper.GetString("postgres.host"),
viper.GetInt("postgres.port"),
viper.GetString("postgres.sslmode"),
)
if err != nil {
return err
@ -209,7 +215,7 @@ func setup() error {
}
// Connect to the generated test db
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
if err != nil {
return err
}
@ -221,6 +227,7 @@ func setup() error {
testCfg.Postgres.DBName,
testCfg.Postgres.User,
testCfg.Postgres.Pass,
testCfg.Postgres.SSLMode,
))
testPassFilePath := passDir + "/testpwfile"