sqlboiler/boil/query_builders.go

371 lines
7.7 KiB
Go
Raw Normal View History

2016-08-06 23:42:22 +02:00
package boil
import (
"bytes"
"fmt"
"regexp"
"sort"
2016-08-06 23:42:22 +02:00
"strings"
2016-08-09 09:59:30 +02:00
"github.com/vattle/sqlboiler/strmangle"
2016-08-06 23:42:22 +02:00
)
// rgxIdentifier matches a single SQL identifier, case-insensitively. Each
// part starts with a letter or underscore, continues with letters, digits or
// underscores, and may carry a double quote on either side; further parts
// may follow, separated by dots (e.g. a, "a", a.b, "a"."b").
var rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
// buildQuery renders q into a SQL string and its arguments.
//
// Raw SQL set on the query is returned untouched together with its own
// arguments. Otherwise the statement kind is chosen in this order: DELETE when
// q.delete is set, UPDATE when q.update has entries, SELECT in all other
// cases.
func buildQuery(q *Query) (string, []interface{}) {
	if len(q.plainSQL.sql) != 0 {
		return q.plainSQL.sql, q.plainSQL.args
	}

	build := buildSelectQuery
	if q.delete {
		build = buildDeleteQuery
	} else if len(q.update) > 0 {
		build = buildUpdateQuery
	}

	stmt, params := build(q)
	return stmt.String(), params
}
func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("SELECT ")
2016-08-07 00:10:35 +02:00
// Wrap the select in the modifier function
hasModFunc := len(q.modFunction) != 0
if hasModFunc {
fmt.Fprintf(buf, "%s(", q.modFunction)
2016-08-06 23:42:22 +02:00
}
2016-08-07 00:10:35 +02:00
hasSelectCols := len(q.selectCols) != 0
hasJoins := len(q.joins) != 0
2016-08-07 23:09:56 +02:00
if hasSelectCols && hasJoins && !hasModFunc {
selectColsWithAs := writeAsStatements(q)
// Don't identQuoteSlice - writeAsStatements does this
buf.WriteString(strings.Join(selectColsWithAs, ", "))
2016-08-07 00:10:35 +02:00
} else if hasSelectCols {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), ", "))
} else if hasJoins {
selectColsWithStars := writeStars(q)
buf.WriteString(strings.Join(selectColsWithStars, ", "))
2016-08-06 23:42:22 +02:00
} else {
2016-08-07 00:10:35 +02:00
buf.WriteByte('*')
2016-08-06 23:42:22 +02:00
}
2016-08-07 00:10:35 +02:00
if hasModFunc {
buf.WriteByte(')')
2016-08-06 23:42:22 +02:00
}
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))
var args []interface{}
for _, j := range q.joins {
if j.kind != JoinInner {
panic("only inner joins are supported")
}
fmt.Fprintf(buf, " INNER JOIN %s", j.clause)
args = append(args, j.args...)
}
2016-08-06 23:42:22 +02:00
where, whereArgs := whereClause(q, len(args)+1)
2016-08-06 23:42:22 +02:00
buf.WriteString(where)
args = append(args, whereArgs...)
writeModifiers(q, buf, &args)
buf.WriteByte(';')
return buf, args
}
// buildDeleteQuery renders q as a DELETE statement terminated by ';'.
// The FROM tables are identifier-quoted, WHERE placeholders are numbered from
// 1, and any modifiers set on the query are appended after the WHERE clause.
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
	tables := strings.Join(strmangle.IdentQuoteSlice(q.from), ", ")
	cond, params := whereClause(q, 1)

	out := &bytes.Buffer{}
	out.WriteString("DELETE FROM " + tables + cond)
	writeModifiers(q, out, &params)
	out.WriteByte(';')

	return out, params
}
// buildUpdateQuery renders q as an UPDATE statement terminated by ';'.
// The columns to set come from the keys of q.update and are emitted in sorted
// order; their values lead the returned argument list, followed by the WHERE
// arguments and anything writeModifiers contributes.
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
	out := &bytes.Buffer{}
	out.WriteString("UPDATE ")
	out.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ", "))

	// Map iteration order is not stable, so sort the column names to keep the
	// generated SQL and the argument order deterministic.
	names := make([]string, 0, len(q.update))
	for name := range q.update {
		names = append(names, name)
	}
	sort.Strings(names)

	var params []interface{}
	quoted := make([]string, len(names))
	for i, name := range names {
		params = append(params, q.update[name])
		quoted[i] = strmangle.IdentQuote(name)
	}

	// NOTE(review): Placeholders(n, 1, 1) presumably yields $1..$n; confirm
	// against strmangle.
	fmt.Fprintf(out, " SET (%s) = (%s)",
		strings.Join(quoted, ", "),
		strmangle.Placeholders(len(quoted), 1, 1),
	)

	// WHERE placeholders continue numbering after the SET values.
	cond, condParams := whereClause(q, len(params)+1)
	out.WriteString(cond)
	params = append(params, condParams...)

	writeModifiers(q, out, &params)
	out.WriteByte(';')

	return out, params
}
func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
2016-08-08 09:28:01 +02:00
if len(q.groupBy) != 0 {
fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
2016-08-08 09:28:01 +02:00
}
if len(q.having) != 0 {
fmt.Fprintf(buf, " HAVING ")
for i, j := range q.having {
if i > 0 {
fmt.Fprintf(buf, ", ")
}
fmt.Fprintf(buf, j.clause)
*args = append(*args, j.args...)
}
2016-08-08 09:28:01 +02:00
}
2016-08-06 23:42:22 +02:00
if len(q.orderBy) != 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(q.orderBy, ", "))
2016-08-06 23:42:22 +02:00
}
if q.limit != 0 {
fmt.Fprintf(buf, " LIMIT %d", q.limit)
}
if q.offset != 0 {
fmt.Fprintf(buf, " OFFSET %d", q.offset)
}
}
2016-08-07 23:09:56 +02:00
// writeStars builds one `<table>.*` select expression per entry in q.from.
// A from entry consisting of a single token is quoted and used directly.
// Longer entries are parsed with parseFromClause, and the alias is preferred
// over the table name when one is present.
//
// It returns nil if any from entry cannot be parsed, rather than emitting a
// star expression for an empty identifier.
func writeStars(q *Query) []string {
	cols := make([]string, 0, len(q.from))
	for _, f := range q.from {
		toks := strings.Split(f, " ")
		if len(toks) == 1 {
			cols = append(cols, fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0])))
			continue
		}

		alias, name, ok := parseFromClause(toks)
		if !ok {
			// No usable identifier was found; there is nothing sensible to
			// select from this entry.
			return nil
		}

		if len(alias) != 0 {
			name = alias
		}
		cols = append(cols, fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name)))
	}

	return cols
}
func writeAsStatements(q *Query) []string {
2016-08-07 00:10:35 +02:00
cols := make([]string, len(q.selectCols))
2016-08-07 23:09:56 +02:00
for i, col := range q.selectCols {
if !rgxIdentifier.MatchString(col) {
cols[i] = col
continue
}
toks := strings.Split(col, ".")
if len(toks) == 1 {
cols[i] = strmangle.IdentQuote(col)
continue
}
asParts := make([]string, len(toks))
for j, tok := range toks {
asParts[j] = strings.Trim(tok, `"`)
2016-08-07 00:10:35 +02:00
}
2016-08-07 23:09:56 +02:00
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, "."))
2016-08-07 00:10:35 +02:00
}
2016-08-07 23:09:56 +02:00
return cols
2016-08-06 23:42:22 +02:00
}
// whereClause parses a where slice and converts it into a
// single WHERE clause like:
// WHERE (a=$1) AND (b=$2)
//
// startAt specifies what number placeholders start at
func whereClause(q *Query, startAt int) (string, []interface{}) {
2016-08-06 23:42:22 +02:00
if len(q.where) == 0 {
return "", nil
}
buf := &bytes.Buffer{}
var args []interface{}
buf.WriteString(" WHERE ")
for i := 0; i < len(q.where); i++ {
buf.WriteString(fmt.Sprintf("(%s)", q.where[i].clause))
2016-08-06 23:42:22 +02:00
args = append(args, q.where[i].args...)
// break on the last loop
if i == len(q.where)-1 {
break
}
if q.where[i].orSeparator {
buf.WriteString(" OR ")
} else {
2016-08-06 23:42:22 +02:00
buf.WriteString(" AND ")
}
}
whereStr := buf.String()
paramBuf := &bytes.Buffer{}
paramIndex := 0
for counter := 1; ; counter++ {
if paramIndex >= len(whereStr) {
break
}
whereStr = whereStr[paramIndex:]
paramIndex = strings.IndexByte(whereStr, '?')
if paramIndex == -1 {
paramBuf.WriteString(whereStr)
break
}
paramBuf.WriteString(whereStr[:paramIndex] + fmt.Sprintf("$%d", counter))
paramIndex++
}
return convertQuestionMarks(buf.String(), startAt), args
}
// convertQuestionMarks replaces every '?' in clause with a numbered
// placeholder of the form $N. The first '?' becomes $startAt, and the number
// increases by one for each subsequent '?'. All other text is copied through
// unchanged.
//
// It panics if startAt is less than 1, since that cannot produce a valid
// placeholder number.
func convertQuestionMarks(clause string, startAt int) string {
	if startAt < 1 {
		panic("Not a valid start number.")
	}

	paramBuf := &bytes.Buffer{}
	for {
		i := strings.IndexByte(clause, '?')
		if i == -1 {
			break
		}

		paramBuf.WriteString(clause[:i])
		fmt.Fprintf(paramBuf, "$%d", startAt)
		startAt++

		// Continue scanning after the '?' just replaced.
		clause = clause[i+1:]
	}

	// Whatever is left contains no placeholders.
	paramBuf.WriteString(clause)
	return paramBuf.String()
}
// identifierMapping creates a map of all identifiers to potential model names
func identifierMapping(q *Query) map[string]string {
var ids map[string]string
setID := func(alias, name string) {
if ids == nil {
ids = make(map[string]string)
}
ids[alias] = name
}
for _, from := range q.from {
tokens := strings.Split(from, " ")
parseIdentifierClause(tokens, setID)
}
for _, join := range q.joins {
tokens := strings.Split(join.clause, " ")
parseIdentifierClause(tokens, setID)
2016-08-06 23:42:22 +02:00
}
return ids
}
// parseBits takes a set of tokens and looks for something of the form:
// a b
// a as b
// where 'a' and 'b' are valid SQL identifiers
// It only evaluates the first 3 tokens (anything past that is superfluous)
// It stops parsing when it finds "on" or an invalid identifier
func parseIdentifierClause(tokens []string, setID func(string, string)) {
2016-08-07 23:09:56 +02:00
alias, name, ok := parseFromClause(tokens)
if !ok {
panic("could not parse from statement")
}
if len(alias) > 0 {
setID(alias, name)
} else {
setID(name, name)
}
}
2016-08-06 23:42:22 +02:00
2016-08-07 23:09:56 +02:00
func parseFromClause(toks []string) (alias, name string, ok bool) {
if len(toks) > 3 {
toks = toks[:3]
2016-08-06 23:42:22 +02:00
}
2016-08-07 23:09:56 +02:00
sawIdent, sawAs := false, false
for _, tok := range toks {
2016-08-06 23:42:22 +02:00
if t := strings.ToLower(tok); sawIdent && t == "as" {
sawAs = true
continue
} else if sawIdent && t == "on" {
break
}
if !rgxIdentifier.MatchString(tok) {
break
}
if sawIdent || sawAs {
alias = strings.Trim(tok, `"`)
break
}
name = strings.Trim(tok, `"`)
sawIdent = true
2016-08-07 23:09:56 +02:00
ok = true
2016-08-06 23:42:22 +02:00
}
2016-08-07 23:09:56 +02:00
return alias, name, ok
2016-08-06 23:42:22 +02:00
}