- Inserting into a table whose columns all have default values, without passing any values on insertion, no longer generates a syntax error.
595 lines
14 KiB
Go
595 lines
14 KiB
Go
package queries
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"regexp"
|
|
"sort"
|
|
"strings"
|
|
|
|
"github.com/vattle/sqlboiler/strmangle"
|
|
)
|
|
|
|
var (
	// rgxIdentifier matches a bare or double-quoted SQL identifier, optionally
	// followed by further dot-separated identifier parts, case-insensitively
	// (for example: a, "a", a.b, "a"."b").
	rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)

	// rgxInClause splits a clause around its IN keyword (case-insensitive).
	// Group 1 is everything before IN and must end in whitespace, ')' or '?';
	// group 2 is everything after IN and must start with whitespace, '(' or '?'.
	// The character classes deliberately contain no '|': inside a class it is a
	// literal pipe, not alternation, and would let "a|IN|?" match.
	rgxInClause = regexp.MustCompile(`^(?i)(.*[\s\)\?])IN([\s\(\?].*)$`)
)
|
|
|
|
func buildQuery(q *Query) (string, []interface{}) {
|
|
var buf *bytes.Buffer
|
|
var args []interface{}
|
|
|
|
switch {
|
|
case len(q.rawSQL.sql) != 0:
|
|
return q.rawSQL.sql, q.rawSQL.args
|
|
case q.delete:
|
|
buf, args = buildDeleteQuery(q)
|
|
case len(q.update) > 0:
|
|
buf, args = buildUpdateQuery(q)
|
|
default:
|
|
buf, args = buildSelectQuery(q)
|
|
}
|
|
|
|
defer strmangle.PutBuffer(buf)
|
|
|
|
// Cache the generated query for query object re-use
|
|
bufStr := buf.String()
|
|
q.rawSQL.sql = bufStr
|
|
q.rawSQL.args = args
|
|
|
|
return bufStr, args
|
|
}
|
|
|
|
func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|
buf := strmangle.GetBuffer()
|
|
var args []interface{}
|
|
|
|
buf.WriteString("SELECT ")
|
|
|
|
if q.count {
|
|
buf.WriteString("COUNT(")
|
|
}
|
|
|
|
hasSelectCols := len(q.selectCols) != 0
|
|
hasJoins := len(q.joins) != 0
|
|
if hasJoins && hasSelectCols && !q.count {
|
|
selectColsWithAs := writeAsStatements(q)
|
|
// Don't identQuoteSlice - writeAsStatements does this
|
|
buf.WriteString(strings.Join(selectColsWithAs, ", "))
|
|
} else if hasSelectCols {
|
|
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.selectCols), ", "))
|
|
} else if hasJoins && !q.count {
|
|
selectColsWithStars := writeStars(q)
|
|
buf.WriteString(strings.Join(selectColsWithStars, ", "))
|
|
} else {
|
|
buf.WriteByte('*')
|
|
}
|
|
|
|
// close SQL COUNT function
|
|
if q.count {
|
|
buf.WriteByte(')')
|
|
}
|
|
|
|
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
|
|
|
|
if len(q.joins) > 0 {
|
|
argsLen := len(args)
|
|
joinBuf := strmangle.GetBuffer()
|
|
for _, j := range q.joins {
|
|
if j.kind != JoinInner {
|
|
panic("only inner joins are supported")
|
|
}
|
|
fmt.Fprintf(joinBuf, " INNER JOIN %s", j.clause)
|
|
args = append(args, j.args...)
|
|
}
|
|
var resp string
|
|
if q.dialect.IndexPlaceholders {
|
|
resp, _ = convertQuestionMarks(joinBuf.String(), argsLen+1)
|
|
} else {
|
|
resp = joinBuf.String()
|
|
}
|
|
fmt.Fprintf(buf, resp)
|
|
strmangle.PutBuffer(joinBuf)
|
|
}
|
|
|
|
where, whereArgs := whereClause(q, len(args)+1)
|
|
buf.WriteString(where)
|
|
if len(whereArgs) != 0 {
|
|
args = append(args, whereArgs...)
|
|
}
|
|
|
|
in, inArgs := inClause(q, len(args)+1)
|
|
buf.WriteString(in)
|
|
if len(inArgs) != 0 {
|
|
args = append(args, inArgs...)
|
|
}
|
|
|
|
writeModifiers(q, buf, &args)
|
|
|
|
buf.WriteByte(';')
|
|
return buf, args
|
|
}
|
|
|
|
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|
var args []interface{}
|
|
buf := strmangle.GetBuffer()
|
|
|
|
buf.WriteString("DELETE FROM ")
|
|
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
|
|
|
|
where, whereArgs := whereClause(q, 1)
|
|
if len(whereArgs) != 0 {
|
|
args = append(args, whereArgs...)
|
|
}
|
|
buf.WriteString(where)
|
|
|
|
in, inArgs := inClause(q, len(args)+1)
|
|
if len(inArgs) != 0 {
|
|
args = append(args, inArgs...)
|
|
}
|
|
buf.WriteString(in)
|
|
|
|
writeModifiers(q, buf, &args)
|
|
|
|
buf.WriteByte(';')
|
|
|
|
return buf, args
|
|
}
|
|
|
|
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|
buf := strmangle.GetBuffer()
|
|
|
|
buf.WriteString("UPDATE ")
|
|
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, q.from), ", "))
|
|
|
|
cols := make(sort.StringSlice, len(q.update))
|
|
var args []interface{}
|
|
|
|
count := 0
|
|
for name := range q.update {
|
|
cols[count] = name
|
|
count++
|
|
}
|
|
|
|
cols.Sort()
|
|
|
|
for i := 0; i < len(cols); i++ {
|
|
args = append(args, q.update[cols[i]])
|
|
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, cols[i])
|
|
}
|
|
|
|
buf.WriteString(fmt.Sprintf(
|
|
" SET (%s) = (%s)",
|
|
strings.Join(cols, ", "),
|
|
strmangle.Placeholders(q.dialect.IndexPlaceholders, len(cols), 1, 1)),
|
|
)
|
|
|
|
where, whereArgs := whereClause(q, len(args)+1)
|
|
if len(whereArgs) != 0 {
|
|
args = append(args, whereArgs...)
|
|
}
|
|
buf.WriteString(where)
|
|
|
|
in, inArgs := inClause(q, len(args)+1)
|
|
if len(inArgs) != 0 {
|
|
args = append(args, inArgs...)
|
|
}
|
|
buf.WriteString(in)
|
|
|
|
writeModifiers(q, buf, &args)
|
|
|
|
buf.WriteByte(';')
|
|
|
|
return buf, args
|
|
}
|
|
|
|
// BuildUpsertQueryMySQL builds a SQL statement string using the upsertData provided.
|
|
func BuildUpsertQueryMySQL(dia Dialect, tableName string, update, whitelist []string) string {
|
|
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
|
|
|
|
buf := strmangle.GetBuffer()
|
|
defer strmangle.PutBuffer(buf)
|
|
|
|
var columns string
|
|
if len(whitelist) != 0 {
|
|
columns = strings.Join(whitelist, ", ")
|
|
}
|
|
|
|
if len(update) == 0 {
|
|
fmt.Fprintf(
|
|
buf,
|
|
"INSERT IGNORE INTO %s (%s) VALUES (%s)",
|
|
tableName,
|
|
columns,
|
|
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
|
|
)
|
|
return buf.String()
|
|
}
|
|
|
|
fmt.Fprintf(
|
|
buf,
|
|
"INSERT INTO %s (%s) VALUES (%s) ON DUPLICATE KEY UPDATE ",
|
|
tableName,
|
|
columns,
|
|
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1),
|
|
)
|
|
|
|
for i, v := range update {
|
|
if i != 0 {
|
|
buf.WriteByte(',')
|
|
}
|
|
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
|
|
buf.WriteString(quoted)
|
|
buf.WriteString(" = VALUES(")
|
|
buf.WriteString(quoted)
|
|
buf.WriteByte(')')
|
|
}
|
|
|
|
return buf.String()
|
|
}
|
|
|
|
// BuildUpsertQueryPostgres builds a SQL statement string using the upsertData provided.
|
|
func BuildUpsertQueryPostgres(dia Dialect, tableName string, updateOnConflict bool, ret, update, conflict, whitelist []string) string {
|
|
conflict = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, conflict)
|
|
whitelist = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, whitelist)
|
|
ret = strmangle.IdentQuoteSlice(dia.LQ, dia.RQ, ret)
|
|
|
|
buf := strmangle.GetBuffer()
|
|
defer strmangle.PutBuffer(buf)
|
|
|
|
columns := "DEFAULT VALUES"
|
|
if len(whitelist) != 0 {
|
|
columns = fmt.Sprintf("(%s) VALUES (%s)",
|
|
strings.Join(whitelist, ", "),
|
|
strmangle.Placeholders(dia.IndexPlaceholders, len(whitelist), 1, 1))
|
|
}
|
|
|
|
fmt.Fprintf(
|
|
buf,
|
|
"INSERT INTO %s %s ON CONFLICT ",
|
|
tableName,
|
|
columns,
|
|
)
|
|
|
|
if !updateOnConflict || len(update) == 0 {
|
|
buf.WriteString("DO NOTHING")
|
|
} else {
|
|
buf.WriteByte('(')
|
|
buf.WriteString(strings.Join(conflict, ", "))
|
|
buf.WriteString(") DO UPDATE SET ")
|
|
|
|
for i, v := range update {
|
|
if i != 0 {
|
|
buf.WriteByte(',')
|
|
}
|
|
quoted := strmangle.IdentQuote(dia.LQ, dia.RQ, v)
|
|
buf.WriteString(quoted)
|
|
buf.WriteString(" = EXCLUDED.")
|
|
buf.WriteString(quoted)
|
|
}
|
|
}
|
|
|
|
if len(ret) != 0 {
|
|
buf.WriteString(" RETURNING ")
|
|
buf.WriteString(strings.Join(ret, ", "))
|
|
}
|
|
|
|
return buf.String()
|
|
}
|
|
|
|
func writeModifiers(q *Query, buf *bytes.Buffer, args *[]interface{}) {
|
|
if len(q.groupBy) != 0 {
|
|
fmt.Fprintf(buf, " GROUP BY %s", strings.Join(q.groupBy, ", "))
|
|
}
|
|
|
|
if len(q.having) != 0 {
|
|
argsLen := len(*args)
|
|
havingBuf := strmangle.GetBuffer()
|
|
fmt.Fprintf(havingBuf, " HAVING ")
|
|
for i, j := range q.having {
|
|
if i > 0 {
|
|
fmt.Fprintf(havingBuf, ", ")
|
|
}
|
|
fmt.Fprintf(havingBuf, j.clause)
|
|
*args = append(*args, j.args...)
|
|
}
|
|
var resp string
|
|
if q.dialect.IndexPlaceholders {
|
|
resp, _ = convertQuestionMarks(havingBuf.String(), argsLen+1)
|
|
} else {
|
|
resp = havingBuf.String()
|
|
}
|
|
fmt.Fprintf(buf, resp)
|
|
strmangle.PutBuffer(havingBuf)
|
|
}
|
|
|
|
if len(q.orderBy) != 0 {
|
|
buf.WriteString(" ORDER BY ")
|
|
buf.WriteString(strings.Join(q.orderBy, ", "))
|
|
}
|
|
|
|
if q.limit != 0 {
|
|
fmt.Fprintf(buf, " LIMIT %d", q.limit)
|
|
}
|
|
if q.offset != 0 {
|
|
fmt.Fprintf(buf, " OFFSET %d", q.offset)
|
|
}
|
|
|
|
if len(q.forlock) != 0 {
|
|
fmt.Fprintf(buf, " FOR %s", q.forlock)
|
|
}
|
|
}
|
|
|
|
func writeStars(q *Query) []string {
|
|
cols := make([]string, len(q.from))
|
|
for i, f := range q.from {
|
|
toks := strings.Split(f, " ")
|
|
if len(toks) == 1 {
|
|
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, toks[0]))
|
|
continue
|
|
}
|
|
|
|
alias, name, ok := parseFromClause(toks)
|
|
if !ok {
|
|
return nil
|
|
}
|
|
|
|
if len(alias) != 0 {
|
|
name = alias
|
|
}
|
|
cols[i] = fmt.Sprintf(`%s.*`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, name))
|
|
}
|
|
|
|
return cols
|
|
}
|
|
|
|
func writeAsStatements(q *Query) []string {
|
|
cols := make([]string, len(q.selectCols))
|
|
for i, col := range q.selectCols {
|
|
if !rgxIdentifier.MatchString(col) {
|
|
cols[i] = col
|
|
continue
|
|
}
|
|
|
|
toks := strings.Split(col, ".")
|
|
if len(toks) == 1 {
|
|
cols[i] = strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col)
|
|
continue
|
|
}
|
|
|
|
asParts := make([]string, len(toks))
|
|
for j, tok := range toks {
|
|
asParts[j] = strings.Trim(tok, `"`)
|
|
}
|
|
|
|
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(q.dialect.LQ, q.dialect.RQ, col), strings.Join(asParts, "."))
|
|
}
|
|
|
|
return cols
|
|
}
|
|
|
|
// whereClause parses a where slice and converts it into a
|
|
// single WHERE clause like:
|
|
// WHERE (a=$1) AND (b=$2)
|
|
//
|
|
// startAt specifies what number placeholders start at
|
|
func whereClause(q *Query, startAt int) (string, []interface{}) {
|
|
if len(q.where) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
buf := strmangle.GetBuffer()
|
|
defer strmangle.PutBuffer(buf)
|
|
var args []interface{}
|
|
|
|
buf.WriteString(" WHERE ")
|
|
for i, where := range q.where {
|
|
if i != 0 {
|
|
if where.orSeparator {
|
|
buf.WriteString(" OR ")
|
|
} else {
|
|
buf.WriteString(" AND ")
|
|
}
|
|
}
|
|
|
|
buf.WriteString(fmt.Sprintf("(%s)", where.clause))
|
|
args = append(args, where.args...)
|
|
}
|
|
|
|
var resp string
|
|
if q.dialect.IndexPlaceholders {
|
|
resp, _ = convertQuestionMarks(buf.String(), startAt)
|
|
} else {
|
|
resp = buf.String()
|
|
}
|
|
|
|
return resp, args
|
|
}
|
|
|
|
// inClause parses an in slice and converts it into a
|
|
// single IN clause, like:
|
|
// WHERE ("a", "b") IN (($1,$2),($3,$4)).
|
|
func inClause(q *Query, startAt int) (string, []interface{}) {
|
|
if len(q.in) == 0 {
|
|
return "", nil
|
|
}
|
|
|
|
buf := strmangle.GetBuffer()
|
|
defer strmangle.PutBuffer(buf)
|
|
var args []interface{}
|
|
|
|
if len(q.where) == 0 {
|
|
buf.WriteString(" WHERE ")
|
|
}
|
|
|
|
for i, in := range q.in {
|
|
ln := len(in.args)
|
|
// We only prefix the OR and AND separators after the first
|
|
// clause has been generated UNLESS there is already a where
|
|
// clause that we have to add on to.
|
|
if i != 0 || len(q.where) > 0 {
|
|
if in.orSeparator {
|
|
buf.WriteString(" OR ")
|
|
} else {
|
|
buf.WriteString(" AND ")
|
|
}
|
|
}
|
|
|
|
matches := rgxInClause.FindStringSubmatch(in.clause)
|
|
// If we can't find any matches attempt a simple replace with 1 group.
|
|
// Clauses that fit this criteria will not be able to contain ? in their
|
|
// column name side, however if this case is being hit then the regexp
|
|
// probably needs adjustment, or the user is passing in invalid clauses.
|
|
if matches == nil {
|
|
clause, count := convertInQuestionMarks(q.dialect.IndexPlaceholders, in.clause, startAt, 1, ln)
|
|
buf.WriteString(clause)
|
|
startAt = startAt + count
|
|
} else {
|
|
leftSide := strings.TrimSpace(matches[1])
|
|
rightSide := strings.TrimSpace(matches[2])
|
|
// If matches are found, we have to parse the left side (column side)
|
|
// of the clause to determine how many columns they are using.
|
|
// This number determines the groupAt for the convert function.
|
|
cols := strings.Split(leftSide, ",")
|
|
cols = strmangle.IdentQuoteSlice(q.dialect.LQ, q.dialect.RQ, cols)
|
|
groupAt := len(cols)
|
|
|
|
var leftClause string
|
|
var leftCount int
|
|
if q.dialect.IndexPlaceholders {
|
|
leftClause, leftCount = convertQuestionMarks(strings.Join(cols, ","), startAt)
|
|
} else {
|
|
// Count the number of cols that are question marks, so we know
|
|
// how much to offset convertInQuestionMarks by
|
|
for _, v := range cols {
|
|
if v == "?" {
|
|
leftCount++
|
|
}
|
|
}
|
|
leftClause = strings.Join(cols, ",")
|
|
}
|
|
rightClause, rightCount := convertInQuestionMarks(q.dialect.IndexPlaceholders, rightSide, startAt+leftCount, groupAt, ln-leftCount)
|
|
buf.WriteString(leftClause)
|
|
buf.WriteString(" IN ")
|
|
buf.WriteString(rightClause)
|
|
startAt = startAt + leftCount + rightCount
|
|
}
|
|
|
|
args = append(args, in.args...)
|
|
}
|
|
|
|
return buf.String(), args
|
|
}
|
|
|
|
// convertInQuestionMarks finds the first unescaped occurrence of ? and swaps it
|
|
// with a list of numbered placeholders, starting at startAt.
|
|
// It uses groupAt to determine how many placeholders should be in each group,
|
|
// for example, groupAt 2 would result in: (($1,$2),($3,$4))
|
|
// and groupAt 1 would result in ($1,$2,$3,$4)
|
|
func convertInQuestionMarks(indexPlaceholders bool, clause string, startAt, groupAt, total int) (string, int) {
|
|
if startAt == 0 || len(clause) == 0 {
|
|
panic("Not a valid start number.")
|
|
}
|
|
|
|
paramBuf := strmangle.GetBuffer()
|
|
defer strmangle.PutBuffer(paramBuf)
|
|
|
|
foundAt := -1
|
|
for i := 0; i < len(clause); i++ {
|
|
if (clause[i] == '?' && i == 0) || (clause[i] == '?' && clause[i-1] != '\\') {
|
|
foundAt = i
|
|
break
|
|
}
|
|
}
|
|
|
|
if foundAt == -1 {
|
|
return strings.Replace(clause, `\?`, "?", -1), 0
|
|
}
|
|
|
|
paramBuf.WriteString(clause[:foundAt])
|
|
paramBuf.WriteByte('(')
|
|
paramBuf.WriteString(strmangle.Placeholders(indexPlaceholders, total, startAt, groupAt))
|
|
paramBuf.WriteByte(')')
|
|
paramBuf.WriteString(clause[foundAt+1:])
|
|
|
|
// Remove all backslashes from escaped question-marks
|
|
ret := strings.Replace(paramBuf.String(), `\?`, "?", -1)
|
|
return ret, total
|
|
}
|
|
|
|
// convertQuestionMarks converts each occurrence of ? with $<number>
|
|
// where <number> is an incrementing digit starting at startAt.
|
|
// If question-mark (?) is escaped using back-slash (\), it will be ignored.
|
|
func convertQuestionMarks(clause string, startAt int) (string, int) {
|
|
if startAt == 0 {
|
|
panic("Not a valid start number.")
|
|
}
|
|
|
|
paramBuf := strmangle.GetBuffer()
|
|
defer strmangle.PutBuffer(paramBuf)
|
|
paramIndex := 0
|
|
total := 0
|
|
|
|
for {
|
|
if paramIndex >= len(clause) {
|
|
break
|
|
}
|
|
|
|
clause = clause[paramIndex:]
|
|
paramIndex = strings.IndexByte(clause, '?')
|
|
|
|
if paramIndex == -1 {
|
|
paramBuf.WriteString(clause)
|
|
break
|
|
}
|
|
|
|
escapeIndex := strings.Index(clause, `\?`)
|
|
if escapeIndex != -1 && paramIndex > escapeIndex {
|
|
paramBuf.WriteString(clause[:escapeIndex] + "?")
|
|
paramIndex++
|
|
continue
|
|
}
|
|
|
|
paramBuf.WriteString(clause[:paramIndex] + fmt.Sprintf("$%d", startAt))
|
|
total++
|
|
startAt++
|
|
paramIndex++
|
|
}
|
|
|
|
return paramBuf.String(), total
|
|
}
|
|
|
|
// parseFromClause will parse something that looks like
|
|
// a
|
|
// a b
|
|
// a as b
|
|
func parseFromClause(toks []string) (alias, name string, ok bool) {
|
|
if len(toks) > 3 {
|
|
toks = toks[:3]
|
|
}
|
|
|
|
sawIdent, sawAs := false, false
|
|
for _, tok := range toks {
|
|
if t := strings.ToLower(tok); sawIdent && t == "as" {
|
|
sawAs = true
|
|
continue
|
|
} else if sawIdent && t == "on" {
|
|
break
|
|
}
|
|
|
|
if !rgxIdentifier.MatchString(tok) {
|
|
break
|
|
}
|
|
|
|
if sawIdent || sawAs {
|
|
alias = strings.Trim(tok, `"`)
|
|
break
|
|
}
|
|
|
|
name = strings.Trim(tok, `"`)
|
|
sawIdent = true
|
|
ok = true
|
|
}
|
|
|
|
return alias, name, ok
|
|
}
|