Merge branch 'master' of github.com:nullbio/sqlboiler
This commit is contained in:
commit
a3ea7a3c0c
9 changed files with 61 additions and 17 deletions
|
@ -3,6 +3,7 @@ package drivers
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
// Side-effect import sql driver
|
||||
_ "github.com/lib/pq"
|
||||
|
@ -22,13 +23,37 @@ type PostgresDriver struct {
|
|||
// the database connection once an object has been obtained.
|
||||
func NewPostgresDriver(user, pass, dbname, host string, port int, sslmode string) *PostgresDriver {
|
||||
driver := PostgresDriver{
|
||||
connStr: fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
user, pass, dbname, host, port, sslmode),
|
||||
connStr: BuildQueryString(user, pass, dbname, host, port, sslmode),
|
||||
}
|
||||
|
||||
return &driver
|
||||
}
|
||||
|
||||
// BuildQueryString for Postgres
|
||||
func BuildQueryString(user, pass, dbname, host string, port int, sslmode string) string {
|
||||
parts := []string{}
|
||||
if len(user) != 0 {
|
||||
parts = append(parts, fmt.Sprintf("user=%s", user))
|
||||
}
|
||||
if len(pass) != 0 {
|
||||
parts = append(parts, fmt.Sprintf("password=%s", pass))
|
||||
}
|
||||
if len(dbname) != 0 {
|
||||
parts = append(parts, fmt.Sprintf("dbname=%s", dbname))
|
||||
}
|
||||
if len(host) != 0 {
|
||||
parts = append(parts, fmt.Sprintf("host=%s", host))
|
||||
}
|
||||
if port != 0 {
|
||||
parts = append(parts, fmt.Sprintf("port=%d", port))
|
||||
}
|
||||
if len(sslmode) != 0 {
|
||||
parts = append(parts, fmt.Sprintf("sslmode=%s", sslmode))
|
||||
}
|
||||
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
// Open opens the database connection using the connection string
|
||||
func (p *PostgresDriver) Open() error {
|
||||
var err error
|
||||
|
|
|
@ -1 +1 @@
|
|||
SELECT "count(*) as ab, thing as bd","stuff" FROM "t";
|
||||
SELECT * FROM "q" ORDER BY a ASC,b DESC;
|
1
boil/_fixtures/03.sql
Normal file
1
boil/_fixtures/03.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "count(*) as ab, thing as bd","stuff" FROM "t";
|
|
@ -86,6 +86,11 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|||
where, args := whereClause(q)
|
||||
buf.WriteString(where)
|
||||
|
||||
if len(q.orderBy) != 0 {
|
||||
buf.WriteString(" ORDER BY ")
|
||||
buf.WriteString(strings.Join(q.orderBy, ","))
|
||||
}
|
||||
|
||||
if q.limit != 0 {
|
||||
fmt.Fprintf(buf, " LIMIT %d", q.limit)
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ func TestBuildQuery(t *testing.T) {
|
|||
}{
|
||||
{&Query{from: "t"}, nil},
|
||||
{&Query{from: "q", limit: 5, offset: 6}, nil},
|
||||
{&Query{from: "q", orderBy: []string{"a ASC", "b DESC"}}, nil},
|
||||
{&Query{from: "t", selectCols: []string{"count(*) as ab, thing as bd", "stuff"}}, nil},
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/nullbio/sqlboiler/strmangle"
|
||||
"github.com/satori/go.uuid"
|
||||
"gopkg.in/nullbio/null.v4"
|
||||
)
|
||||
|
||||
|
@ -241,9 +242,12 @@ func randomizeField(field reflect.Value, fieldType string, includeInvalid bool)
|
|||
}
|
||||
case typeNullString:
|
||||
if b {
|
||||
if fieldType == "interval" {
|
||||
switch fieldType {
|
||||
case "interval":
|
||||
newVal = null.NewString(strconv.Itoa((sd.nextInt()%26)+2)+" days", b)
|
||||
} else {
|
||||
case "uuid":
|
||||
newVal = null.NewString(uuid.NewV4().String(), b)
|
||||
default:
|
||||
newVal = null.NewString(randStr(1, sd.nextInt()), b)
|
||||
}
|
||||
} else {
|
||||
|
@ -365,9 +369,12 @@ func randomizeField(field reflect.Value, fieldType string, includeInvalid bool)
|
|||
}
|
||||
newVal = b
|
||||
case reflect.String:
|
||||
if fieldType == "interval" {
|
||||
switch fieldType {
|
||||
case "interval":
|
||||
newVal = strconv.Itoa((sd.nextInt()%26)+2) + " days"
|
||||
} else {
|
||||
case "uuid":
|
||||
newVal = uuid.NewV4().String()
|
||||
default:
|
||||
newVal = randStr(1, sd.nextInt())
|
||||
}
|
||||
case reflect.Slice:
|
||||
|
|
|
@ -214,11 +214,13 @@ var defaultTestMainImports = map[string]imports{
|
|||
`"io/ioutil"`,
|
||||
`"bytes"`,
|
||||
`"database/sql"`,
|
||||
`"path/filepath"`,
|
||||
`"time"`,
|
||||
`"math/rand"`,
|
||||
},
|
||||
thirdParty: importList{
|
||||
`"github.com/nullbio/sqlboiler/boil"`,
|
||||
`"github.com/nullbio/sqlboiler/bdb/drivers"`,
|
||||
`_ "github.com/lib/pq"`,
|
||||
`"github.com/spf13/viper"`,
|
||||
`"github.com/kat-co/vala"`,
|
||||
|
|
3
main.go
3
main.go
|
@ -64,7 +64,7 @@ func main() {
|
|||
rootCmd.PersistentFlags().StringP("output", "o", "models", "The name of the folder to output to")
|
||||
rootCmd.PersistentFlags().StringP("pkgname", "p", "models", "The name you wish to assign to your generated package")
|
||||
|
||||
viper.SetDefault("postgres.ssl_mode", "required")
|
||||
viper.SetDefault("postgres.sslmode", "required")
|
||||
viper.BindPFlags(rootCmd.PersistentFlags())
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
|
@ -130,7 +130,6 @@ func preRun(cmd *cobra.Command, args []string) error {
|
|||
|
||||
err = vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(cmdConfig.Postgres.User, "postgres.user"),
|
||||
vala.StringNotEmpty(cmdConfig.Postgres.Pass, "postgres.pass"),
|
||||
vala.StringNotEmpty(cmdConfig.Postgres.Host, "postgres.host"),
|
||||
vala.Not(vala.Equals(cmdConfig.Postgres.Port, 0, "postgres.port")),
|
||||
vala.StringNotEmpty(cmdConfig.Postgres.DBName, "postgres.dbname"),
|
||||
|
|
|
@ -97,8 +97,7 @@ func dropTestDB() error {
|
|||
|
||||
// DBConnect connects to a database and returns the handle.
|
||||
func DBConnect(user, pass, dbname, host string, port int, sslmode string) (*sql.DB, error) {
|
||||
connStr := fmt.Sprintf("user=%s password=%s dbname=%s host=%s port=%d sslmode=%s",
|
||||
user, pass, dbname, host, port, sslmode)
|
||||
connStr := drivers.BuildQueryString(user, pass, dbname, host, port, sslmode)
|
||||
|
||||
return sql.Open("postgres", connStr)
|
||||
}
|
||||
|
@ -133,7 +132,6 @@ func setup() error {
|
|||
|
||||
err = vala.BeginValidation().Validate(
|
||||
vala.StringNotEmpty(testCfg.Postgres.User, "postgres.user"),
|
||||
vala.StringNotEmpty(testCfg.Postgres.Pass, "postgres.pass"),
|
||||
vala.StringNotEmpty(testCfg.Postgres.Host, "postgres.host"),
|
||||
vala.Not(vala.Equals(testCfg.Postgres.Port, 0, "postgres.port")),
|
||||
vala.StringNotEmpty(testCfg.Postgres.DBName, "postgres.dbname"),
|
||||
|
@ -163,15 +161,18 @@ func setup() error {
|
|||
defer os.RemoveAll(passDir)
|
||||
|
||||
// Write the postgres user password to a tmp file for pg_dump
|
||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
||||
pwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
|
||||
viper.GetString("postgres.host"),
|
||||
viper.GetInt("postgres.port"),
|
||||
viper.GetString("postgres.dbname"),
|
||||
viper.GetString("postgres.user"),
|
||||
viper.GetString("postgres.pass"),
|
||||
))
|
||||
|
||||
passFilePath := passDir + "/pwfile"
|
||||
if pw := viper.GetString("postgres.pass"); len(pw) > 0 {
|
||||
pwBytes = []byte(fmt.Sprintf("%s:%s", pwBytes, pw))
|
||||
}
|
||||
|
||||
passFilePath := filepath.Join(passDir, "pwfile")
|
||||
|
||||
err = ioutil.WriteFile(passFilePath, pwBytes, 0600)
|
||||
if err != nil {
|
||||
|
@ -229,14 +230,17 @@ func setup() error {
|
|||
}
|
||||
|
||||
// Write the test config credentials to a tmp file for pg_dump
|
||||
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s:%s",
|
||||
testPwBytes := []byte(fmt.Sprintf("%s:%d:%s:%s",
|
||||
testCfg.Postgres.Host,
|
||||
testCfg.Postgres.Port,
|
||||
testCfg.Postgres.DBName,
|
||||
testCfg.Postgres.User,
|
||||
testCfg.Postgres.Pass,
|
||||
))
|
||||
|
||||
if len(testCfg.Postgres.Pass) > 0 {
|
||||
testPwBytes = []byte(fmt.Sprintf("%s:%s", testPwBytes, testCfg.Postgres.Pass))
|
||||
}
|
||||
|
||||
testPassFilePath := passDir + "/testpwfile"
|
||||
|
||||
err = ioutil.WriteFile(testPassFilePath, testPwBytes, 0600)
|
||||
|
|
Loading…
Add table
Reference in a new issue