Add postgres sslmode to configuration.
This commit is contained in:
parent
8cb8d1348a
commit
2b666e74de
5 changed files with 35 additions and 23 deletions
|
@ -20,10 +20,10 @@ type PostgresDriver struct {
|
||||||
// returns a pointer to a PostgresDriver object. Note that it is required to
|
// returns a pointer to a PostgresDriver object. Note that it is required to
|
||||||
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close
|
// call PostgresDriver.Open() and PostgresDriver.Close() to open and close
|
||||||
// the database connection once an object has been obtained.
|
// the database connection once an object has been obtained.
|
||||||
func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDriver {
|
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
|
||||||
driver := PostgresDriver{
|
driver := PostgresDriver{
|
||||||
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||||
user, pass, dbname, host, port),
|
user, pass, dbname, host, port, sslmode),
|
||||||
}
|
}
|
||||||
|
|
||||||
return &driver
|
return &driver
|
||||||
|
|
11
config.go
11
config.go
|
@ -12,9 +12,10 @@ type Config struct {
|
||||||
|
|
||||||
// PostgresConfig configures a postgres database
|
// PostgresConfig configures a postgres database
|
||||||
type PostgresConfig struct {
|
type PostgresConfig struct {
|
||||||
User string `toml:"user"`
|
User string `toml:"user"`
|
||||||
Pass string `toml:"pass"`
|
Pass string `toml:"pass"`
|
||||||
Host string `toml:"host"`
|
Host string `toml:"host"`
|
||||||
Port int `toml:"port"`
|
Port int `toml:"port"`
|
||||||
DBName string `toml:"dbname"`
|
DBName string `toml:"dbname"`
|
||||||
|
SSLMode string `toml:"sslmode"`
|
||||||
}
|
}
|
||||||
|
|
13
main.go
13
main.go
|
@ -63,6 +63,7 @@ func main() {
|
||||||
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
||||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||||
|
|
||||||
|
viper.SetDefault("postgres.ssl_mode", "required")
|
||||||
viper.BindPFlags(rootCmd.PersistentFlags())
|
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||||
|
|
||||||
if err := rootCmd.Execute(); err != nil {
|
if err := rootCmd.Execute(); err != nil {
|
||||||
|
@ -110,11 +111,12 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
|
|
||||||
if viper.IsSet("postgres.dbname") {
|
if viper.IsSet("postgres.dbname") {
|
||||||
cmdConfig.Postgres = PostgresConfig{
|
cmdConfig.Postgres = PostgresConfig{
|
||||||
User: viper.GetString("postgres.user"),
|
User: viper.GetString("postgres.user"),
|
||||||
Pass: viper.GetString("postgres.pass"),
|
Pass: viper.GetString("postgres.pass"),
|
||||||
Host: viper.GetString("postgres.host"),
|
Host: viper.GetString("postgres.host"),
|
||||||
Port: viper.GetInt("postgres.port"),
|
Port: viper.GetInt("postgres.port"),
|
||||||
DBName: viper.GetString("postgres.dbname"),
|
DBName: viper.GetString("postgres.dbname"),
|
||||||
|
SSLMode: viper.GetString("postgres.sslmode"),
|
||||||
}
|
}
|
||||||
|
|
||||||
err = vala.BeginValidation().Validate(
|
err = vala.BeginValidation().Validate(
|
||||||
|
@ -123,6 +125,7 @@ func preRun(cmd *cobra.Command, args []string) error {
|
||||||
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
|
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
|
||||||
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
|
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
|
||||||
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),
|
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),
|
||||||
|
vala.StringNotEmpty(cmdConfig.Postgres.SSLMode, "postgres.sslmode"),
|
||||||
).Check()
|
).Check()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -177,6 +177,7 @@ func (s *State) initDriver(driverName string) error {
|
||||||
s.Config.Postgres.DBName,
|
s.Config.Postgres.DBName,
|
||||||
s.Config.Postgres.Host,
|
s.Config.Postgres.Host,
|
||||||
s.Config.Postgres.Port,
|
s.Config.Postgres.Port,
|
||||||
|
s.Config.Postgres.SSLMode,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
type PostgresCfg struct {
|
type PostgresCfg struct {
|
||||||
User string `toml:"user"`
|
User string `toml:"user"`
|
||||||
Pass string `toml:"pass"`
|
Pass string `toml:"pass"`
|
||||||
Host string `toml:"host"`
|
Host string `toml:"host"`
|
||||||
Port int `toml:"port"`
|
Port int `toml:"port"`
|
||||||
DBName string `toml:"dbname"`
|
DBName string `toml:"dbname"`
|
||||||
|
SSLMode string `toml:"sslmode"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -79,7 +80,7 @@ func dropTestDB() error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port)
|
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, "template1", testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -93,9 +94,9 @@ func dropTestDB() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// DBConnect connects to a database and returns the handle.
|
// DBConnect connects to a database and returns the handle.
|
||||||
func DBConnect(user, pass, dbname, host string, port int) (*sql.DB, error) {
|
func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
|
||||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d",
|
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||||
user, pass, dbname, host, port)
|
user, pass, dbname, host, port, sslmode)
|
||||||
|
|
||||||
return sql.Open("postgres", connStr)
|
return sql.Open("postgres", connStr)
|
||||||
}
|
}
|
||||||
|
@ -112,12 +113,15 @@ func setup() error {
|
||||||
return fmt.Errorf("Unable to load config file: %s", err)
|
return fmt.Errorf("Unable to load config file: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
viper.SetDefault("postgres.sslmode", "required")
|
||||||
|
|
||||||
// Create a randomized test configuration object.
|
// Create a randomized test configuration object.
|
||||||
testCfg.Postgres.Host = viper.GetString("postgres.host")
|
testCfg.Postgres.Host = viper.GetString("postgres.host")
|
||||||
testCfg.Postgres.Port = viper.GetInt("postgres.port")
|
testCfg.Postgres.Port = viper.GetInt("postgres.port")
|
||||||
testCfg.Postgres.User = viper.GetString("postgres.user")
|
testCfg.Postgres.User = viper.GetString("postgres.user")
|
||||||
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
|
testCfg.Postgres.Pass = viper.GetString("postgres.pass")
|
||||||
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
|
testCfg.Postgres.DBName = getDBNameHash(viper.GetString("postgres.dbname"))
|
||||||
|
testCfg.Postgres.SSLMode = viper.GetString("postgres.sslmode")
|
||||||
|
|
||||||
err = vala.BeginValidation().Validate(
|
err = vala.BeginValidation().Validate(
|
||||||
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
|
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
|
||||||
|
@ -125,6 +129,7 @@ func setup() error {
|
||||||
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
|
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
|
||||||
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
|
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
|
||||||
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
|
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
|
||||||
|
vala.StringNotEmpty(testCfg.Postgres.SSLMode, "postgres.sslmode"),
|
||||||
).Check()
|
).Check()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -192,6 +197,7 @@ func setup() error {
|
||||||
viper.GetString("postgres.dbname"),
|
viper.GetString("postgres.dbname"),
|
||||||
viper.GetString("postgres.host"),
|
viper.GetString("postgres.host"),
|
||||||
viper.GetInt("postgres.port"),
|
viper.GetInt("postgres.port"),
|
||||||
|
viper.GetString("postgres.sslmode"),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -209,7 +215,7 @@ func setup() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect to the generated test db
|
// Connect to the generated test db
|
||||||
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port)
|
dbConn, err = DBConnect(testCfg.Postgres.User, testCfg.Postgres.Pass, testCfg.Postgres.DBName, testCfg.Postgres.Host, testCfg.Postgres.Port, testCfg.Postgres.SSLMode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -221,6 +227,7 @@ func setup() error {
|
||||||
testCfg.Postgres.DBName,
|
testCfg.Postgres.DBName,
|
||||||
testCfg.Postgres.User,
|
testCfg.Postgres.User,
|
||||||
testCfg.Postgres.Pass,
|
testCfg.Postgres.Pass,
|
||||||
|
testCfg.Postgres.SSLMode,
|
||||||
))
|
))
|
||||||
|
|
||||||
testPassFilePath := passDir + "/testpwfile"
|
testPassFilePath := passDir + "/testpwfile"
|
||||||
|
|
Loading…
Add table
Reference in a new issue