Fix query string building a little bit

This commit is contained in:
Aaron L 2016-07-31 20:29:28 -07:00
parent edecf1b704
commit 1eade96bcd
2 changed files with 28 additions and 4 deletions

View file

@ -3,6 +3,7 @@ package drivers
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"strings"
// Side-effect import sql driver // Side-effect import sql driver
_ "github.com/lib/pq" _ "github.com/lib/pq"
@ -22,13 +23,37 @@ type PostgresDriver struct {
// the database connection once an object has been obtained. // the database connection once an object has been obtained.
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver { func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
driver := PostgresDriver{ driver := PostgresDriver{
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s", connStr: BuildQueryString(user, pass, dbname, host, port, sslmode),
user, pass, dbname, host, port, sslmode),
} }
return &driver return &driver
} }
// BuildQueryString for Postgres
func BuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
parts := []string{}
if len(user) != 0 {
parts = append(parts, fmt.Sprintf("user=%s", user))
}
if len(pass) != 0 {
parts = append(parts, fmt.Sprintf("password=%s", pass))
}
if len(dbname) != 0 {
parts = append(parts, fmt.Sprintf("dbname=%s", dbname))
}
if len(host) != 0 {
parts = append(parts, fmt.Sprintf("host=%s", host))
}
if port != 0 {
parts = append(parts, fmt.Sprintf("port=%d", port))
}
if len(sslmode) != 0 {
parts = append(parts, fmt.Sprintf("sslmode=%s", sslmode))
}
return strings.Join(parts, " ")
}
// Open opens the database connection using the connection string // Open opens the database connection using the connection string
func (p *PostgresDriver) Open() error { func (p *PostgresDriver) Open() error {
var err error var err error

View file

@ -64,7 +64,7 @@ func main() {
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to") rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package") rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
viper.SetDefault("postgres.ssl_mode", "required") viper.SetDefault("postgres.sslmode", "required")
viper.BindPFlags(rootCmd.PersistentFlags()) viper.BindPFlags(rootCmd.PersistentFlags())
if err := rootCmd.Execute(); err != nil { if err := rootCmd.Execute(); err != nil {
@ -130,7 +130,6 @@ func preRun(cmd *cobra.Command, args []string) error {
err = vala.BeginValidation().Validate( err = vala.BeginValidation().Validate(
vala.StringNotEmpty(cmdConfig.Postgres.User, "postgres.user"), vala.StringNotEmpty(cmdConfig.Postgres.User, "postgres.user"),
vala.StringNotEmpty(cmdConfig.Postgres.Pass, "postgres.pass"),
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"), vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")), vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"), vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),