Add postgres sslmode to configuration.
This commit is contained in:
parent
8cb8d1348a
commit
2b666e74de
5 changed files with 35 additions and 23 deletions
|
@ -20,10 +20,10 @@ type PostgresDriver struct {
|
|||
// returns a pointer to a PostgresDriver object. Note that it is required to
|
||||
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close
|
||||
// the database connection once an object has been obtained.
|
||||
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver {
|
||||
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
|
||||
driver := PostgresDriver{
|
||||
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
||||
user, pass, dbname, host, port),
|
||||
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
user, pass, dbname, host, port, sslmode),
|
||||
}
|
||||
|
||||
return &driver
|
||||
|
|
11
config.go
11
config.go
|
@ -12,9 +12,10 @@ type Config struct {
|
|||
|
||||
// PostgresConfig configures a postgres database
|
||||
type PostgresConfig struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
SSLMode string `toml:"sslmode"`
|
||||
}
|
||||
|
|
13
main.go
13
main.go
|
@ -63,6 +63,7 @@ func main() {
|
|||
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||
|
||||
viper.SetDefault("postgres.ssl_mode", "required")
|
||||
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
|
@ -110,11 +111,12 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
|
||||
if viper.IsSet("postgres.dbname") {
|
||||
cmdConfig.Postgres = PostgresConfig{
|
||||
User: viper.GetString("postgres.user"),
|
||||
Pass: viper.GetString("postgres.pass"),
|
||||
Host: viper.GetString("postgres.host"),
|
||||
Port: viper.GetInt("postgres.port"),
|
||||
DBName: viper.GetString("postgres.dbname"),
|
||||
User: viper.GetString("postgres.user"),
|
||||
Pass: viper.GetString("postgres.pass"),
|
||||
Host: viper.GetString("postgres.host"),
|
||||
Port: viper.GetInt("postgres.port"),
|
||||
DBName: viper.GetString("postgres.dbname"),
|
||||
SSLMode: viper.GetString("postgres.sslmode"),
|
||||
}
|
||||
|
||||
err = vala.BeginValidation().Validate(
|
||||
|
@ -123,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
|
||||
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
|
||||
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),
|
||||
vala.StringNotEmpty(cmdConfig.Postgres.SSLMode, "postgres.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -177,6 +177,7 @@ func (s *State) initDriver(driverName string) error {
|
|||
s.Config.Postgres.DBName,
|
||||
s.Config.Postgres.Host,
|
||||
s.Config.Postgres.Port,
|
||||
s.Config.Postgres.SSLMode,
|
||||
)
|
||||
}
|
||||
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
type PostgresCfg struct {
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
User string `toml:"user"`
|
||||
Pass string `toml:"pass"`
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
DBName string `toml:"dbname"`
|
||||
SSLMode string `toml:"sslmode"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
|
@ -79,7 +80,7 @@ func dropTestDB() error {
|
|||
}
|
||||
}
|
||||
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port)
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -93,9 +94,9 @@ func dropTestDB() error {
|
|||
}
|
||||
|
||||
// DBConnect connects to a database and returns the handle.
|
||||
func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
|
||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
||||
user, pass, dbname, host, port)
|
||||
func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
|
||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
user, pass, dbname, host, port, sslmode)
|
||||
|
||||
return sql.Open("postgres", connStr)
|
||||
}
|
||||
|
@ -112,12 +113,15 @@ func setup() error {
|
|||
return fmt.Errorf("Unable to load config file: %s", err)
|
||||
}
|
||||
|
||||
viper.SetDefault("postgres.sslmode", "required")
|
||||
|
||||
// Create a randomized test configuration object.
|
||||
testCfg.Postgres.Host = viper.GetString("postgres.host")
|
||||
testCfg.Postgres.Port = viper.GetInt("postgres.port")
|
||||
testCfg.Postgres.User = viper.GetString("postgres.user")
|
||||
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
|
||||
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
|
||||
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
|
||||
|
||||
err = vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
|
||||
|
@ -125,6 +129,7 @@ func setup() error {
|
|||
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
|
||||
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
|
||||
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
|
||||
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
|
||||
).Check()
|
||||
|
||||
if err != nil {
|
||||
|
@ -192,6 +197,7 @@ func setup() error {
|
|||
viper.GetString("postgres.dbname"),
|
||||
viper.GetString("postgres.host"),
|
||||
viper.GetInt("postgres.port"),
|
||||
viper.GetString("postgres.sslmode"),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -209,7 +215,7 @@ func setup() error {
|
|||
}
|
||||
|
||||
// Connect to the generated test db
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
|
||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -221,6 +227,7 @@ func setup() error {
|
|||
testCfg.Postgres.DBName,
|
||||
testCfg.Postgres.User,
|
||||
testCfg.Postgres.Pass,
|
||||
testCfg.Postgres.SSLMode,
|
||||
))
|
||||
|
||||
testPassFilePath := passDir + "/testpwfile"
|
||||
|
|
Loading…
Add table
Reference in a new issue