Reorganize things before tearing apart

This commit is contained in:
Aaron L 2016-08-06 14:42:22 -07:00
parent 8298da6f48
commit 5541b4dce9
5 changed files with 272 additions and 264 deletions

View file

@ -1,31 +1,10 @@
package boil package boil
import ( import (
"bytes"
"database/sql" "database/sql"
"fmt" "fmt"
"regexp"
"strings"
"github.com/nullbio/sqlboiler/strmangle"
) )
type where struct {
clause string
orSeperator bool
args []interface{}
}
type plainSQL struct {
sql string
args []interface{}
}
type join struct {
on string
args []interface{}
}
// Query holds the state for the built up query // Query holds the state for the built up query
type Query struct { type Query struct {
executor Executor executor Executor
@ -44,87 +23,24 @@ type Query struct {
offset int offset int
} }
func buildQuery(q *Query) (string, []interface{}) { type where struct {
var buf *bytes.Buffer clause string
var args []interface{} orSeperator bool
args []interface{}
switch {
case len(q.plainSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args
case q.delete:
buf, args = buildDeleteQuery(q)
case len(q.update) > 0:
buf, args = buildUpdateQuery(q)
default:
buf, args = buildSelectQuery(q)
}
return buf.String(), args
} }
func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { type plainSQL struct {
buf := &bytes.Buffer{} sql string
args []interface{}
buf.WriteString("SELECT ")
if q.count {
buf.WriteString("COUNT(")
}
if len(q.selectCols) > 0 {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `, `))
} else {
buf.WriteByte('*')
}
// close sql COUNT function
if q.count {
buf.WriteString(")")
}
buf.WriteString(" FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ","))
where, args := whereClause(q)
buf.WriteString(where)
if len(q.orderBy) != 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(q.orderBy, ","))
}
if q.limit != 0 {
fmt.Fprintf(buf, " LIMIT %d", q.limit)
}
if q.offset != 0 {
fmt.Fprintf(buf, " OFFSET %d", q.offset)
}
buf.WriteByte(';')
return buf, args
} }
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { type join struct {
buf := &bytes.Buffer{} on string
args []interface{}
buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ","))
where, args := whereClause(q)
buf.WriteString(where)
buf.WriteByte(';')
return buf, args
} }
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { // ExecStatement executes a query that does not need a row returned
buf := &bytes.Buffer{} func ExecStatement(q *Query) (sql.Result, error) {
buf.WriteByte(';')
return buf, nil
}
// ExecQuery executes a query that does not need a row returned
func ExecQuery(q *Query) (sql.Result, error) {
qs, args := buildQuery(q) qs, args := buildQuery(q)
if DebugMode { if DebugMode {
fmt.Fprintln(DebugWriter, qs) fmt.Fprintln(DebugWriter, qs)
@ -132,17 +48,8 @@ func ExecQuery(q *Query) (sql.Result, error) {
return q.executor.Exec(qs, args...) return q.executor.Exec(qs, args...)
} }
// ExecQueryOne executes the query for the One finisher and returns a row // ExecQuery executes the query for the All finisher and returns multiple rows
func ExecQueryOne(q *Query) *sql.Row { func ExecQuery(q *Query) (*sql.Rows, error) {
qs, args := buildQuery(q)
if DebugMode {
fmt.Fprintln(DebugWriter, qs)
}
return q.executor.QueryRow(qs, args...)
}
// ExecQueryAll executes the query for the All finisher and returns multiple rows
func ExecQueryAll(q *Query) (*sql.Rows, error) {
qs, args := buildQuery(q) qs, args := buildQuery(q)
if DebugMode { if DebugMode {
fmt.Fprintln(DebugWriter, qs) fmt.Fprintln(DebugWriter, qs)
@ -266,102 +173,3 @@ func SetLimit(q *Query, limit int) {
func SetOffset(q *Query, offset int) { func SetOffset(q *Query, offset int) {
q.offset = offset q.offset = offset
} }
func whereClause(q *Query) (string, []interface{}) {
if len(q.where) == 0 {
return "", nil
}
buf := &bytes.Buffer{}
var args []interface{}
buf.WriteString(" WHERE ")
for i := 0; i < len(q.where); i++ {
buf.WriteString(fmt.Sprintf("%s", q.where[i].clause))
args = append(args, q.where[i].args...)
if i >= len(q.where)-1 {
continue
}
if q.where[i].orSeperator {
buf.WriteString(" OR ")
} else {
buf.WriteString(" AND ")
}
}
return buf.String(), args
}
var (
rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`)
)
// identifierMapping creates a map of all identifiers to potential model names
func identifierMapping(q *Query) map[string]string {
var ids map[string]string
setID := func(alias, name string) {
if ids == nil {
ids = make(map[string]string)
}
ids[alias] = name
}
for _, from := range q.from {
tokens := strings.Split(from, " ")
parseIdentifierClause(tokens, setID)
}
for _, join := range q.innerJoins {
tokens := strings.Split(join.on, " ")
discard := 0
for rgxJoinIdentifiers.MatchString(tokens[discard]) {
discard++
}
parseIdentifierClause(tokens[discard:], setID)
}
return ids
}
// parseBits takes a set of tokens and looks for something of the form:
// a b
// a as b
// where 'a' and 'b' are valid SQL identifiers
// It only evaluates the first 3 tokens (anything past that is superfluous)
// It stops parsing when it finds "on" or an invalid identifier
func parseIdentifierClause(tokens []string, setID func(string, string)) {
var name, alias string
sawIdent, sawAs := false, false
if len(tokens) > 3 {
tokens = tokens[:3]
}
for _, tok := range tokens {
if t := strings.ToLower(tok); sawIdent && t == "as" {
sawAs = true
continue
} else if sawIdent && t == "on" {
break
}
if !rgxIdentifier.MatchString(tok) {
break
}
if sawIdent || sawAs {
alias = strings.Trim(tok, `"`)
break
}
name = strings.Trim(tok, `"`)
sawIdent = true
}
if len(alias) > 0 {
setID(alias, name)
} else {
setID(name, name)
}
}

196
boil/query_builders.go Normal file
View file

@ -0,0 +1,196 @@
package boil
import (
"bytes"
"fmt"
"regexp"
"strings"
"github.com/nullbio/sqlboiler/strmangle"
)
var (
rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`)
)
func buildQuery(q *Query) (string, []interface{}) {
var buf *bytes.Buffer
var args []interface{}
switch {
case len(q.plainSQL.sql) != 0:
return q.plainSQL.sql, q.plainSQL.args
case q.delete:
buf, args = buildDeleteQuery(q)
case len(q.update) > 0:
buf, args = buildUpdateQuery(q)
default:
buf, args = buildSelectQuery(q)
}
return buf.String(), args
}
func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("SELECT ")
if q.count {
buf.WriteString("COUNT(")
}
if len(q.innerJoins) > 0 && !q.count {
writeComplexSelect(q, buf)
} else {
if len(q.selectCols) > 0 {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `, `))
} else {
buf.WriteByte('*')
}
}
// close sql COUNT function
if q.count {
buf.WriteString(")")
}
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ","))
where, args := whereClause(q)
buf.WriteString(where)
if len(q.orderBy) != 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(q.orderBy, ","))
}
if q.limit != 0 {
fmt.Fprintf(buf, " LIMIT %d", q.limit)
}
if q.offset != 0 {
fmt.Fprintf(buf, " OFFSET %d", q.offset)
}
buf.WriteByte(';')
return buf, args
}
func writeComplexSelect(q *Query, buf *bytes.Buffer) {
}
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteString("DELETE FROM ")
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ","))
where, args := whereClause(q)
buf.WriteString(where)
buf.WriteByte(';')
return buf, args
}
func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf := &bytes.Buffer{}
buf.WriteByte(';')
return buf, nil
}
func whereClause(q *Query) (string, []interface{}) {
if len(q.where) == 0 {
return "", nil
}
buf := &bytes.Buffer{}
var args []interface{}
buf.WriteString(" WHERE ")
for i := 0; i < len(q.where); i++ {
buf.WriteString(fmt.Sprintf("%s", q.where[i].clause))
args = append(args, q.where[i].args...)
if i >= len(q.where)-1 {
continue
}
if q.where[i].orSeperator {
buf.WriteString(" OR ")
} else {
buf.WriteString(" AND ")
}
}
return buf.String(), args
}
// identifierMapping creates a map of all identifiers to potential model names
func identifierMapping(q *Query) map[string]string {
var ids map[string]string
setID := func(alias, name string) {
if ids == nil {
ids = make(map[string]string)
}
ids[alias] = name
}
for _, from := range q.from {
tokens := strings.Split(from, " ")
parseIdentifierClause(tokens, setID)
}
for _, join := range q.innerJoins {
tokens := strings.Split(join.on, " ")
discard := 0
for rgxJoinIdentifiers.MatchString(tokens[discard]) {
discard++
}
parseIdentifierClause(tokens[discard:], setID)
}
return ids
}
// parseBits takes a set of tokens and looks for something of the form:
// a b
// a as b
// where 'a' and 'b' are valid SQL identifiers
// It only evaluates the first 3 tokens (anything past that is superfluous)
// It stops parsing when it finds "on" or an invalid identifier
func parseIdentifierClause(tokens []string, setID func(string, string)) {
var name, alias string
sawIdent, sawAs := false, false
if len(tokens) > 3 {
tokens = tokens[:3]
}
for _, tok := range tokens {
if t := strings.ToLower(tok); sawIdent && t == "as" {
sawAs = true
continue
} else if sawIdent && t == "on" {
break
}
if !rgxIdentifier.MatchString(tok) {
break
}
if sawIdent || sawAs {
alias = strings.Trim(tok, `"`)
break
}
name = strings.Trim(tok, `"`)
sawIdent = true
}
if len(alias) > 0 {
setID(alias, name)
} else {
setID(name, name)
}
}

View file

@ -0,0 +1,59 @@
package boil
import "testing"
func TestIdentifierMapping(t *testing.T) {
t.Parallel()
tests := []struct {
In Query
Out map[string]string
}{
{
In: Query{from: []string{`a`}},
Out: map[string]string{"a": "a"},
},
{
In: Query{from: []string{`"a"`, `b`}},
Out: map[string]string{"a": "a", "b": "b"},
},
{
In: Query{from: []string{`a as b`}},
Out: map[string]string{"b": "a"},
},
{
In: Query{from: []string{`a AS "b"`, `"c" as d`}},
Out: map[string]string{"b": "a", "d": "c"},
},
{
In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
{
In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
}
for i, test := range tests {
m := identifierMapping(&test.In)
for k, v := range test.Out {
val, ok := m[k]
if !ok {
t.Errorf("%d) want: %s = %s, but was missing", i, k, v)
}
if val != v {
t.Errorf("%d) want: %s = %s, got: %s", i, k, v, val)
}
}
}
}

View file

@ -332,59 +332,3 @@ func TestInnerJoin(t *testing.T) {
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins) t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
} }
} }
func TestIdentifierMapping(t *testing.T) {
t.Parallel()
tests := []struct {
In Query
Out map[string]string
}{
{
In: Query{from: []string{`a`}},
Out: map[string]string{"a": "a"},
},
{
In: Query{from: []string{`"a"`, `b`}},
Out: map[string]string{"a": "a", "b": "b"},
},
{
In: Query{from: []string{`a as b`}},
Out: map[string]string{"b": "a"},
},
{
In: Query{from: []string{`a AS "b"`, `"c" as d`}},
Out: map[string]string{"b": "a", "d": "c"},
},
{
In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
{
In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
}
for i, test := range tests {
m := identifierMapping(&test.In)
for k, v := range test.Out {
val, ok := m[k]
if !ok {
t.Errorf("%d) want: %s = %s, but was missing", i, k, v)
}
if val != v {
t.Errorf("%d) want: %s = %s, got: %s", i, k, v, val)
}
}
}
}

View file

@ -11,7 +11,8 @@ import (
// Bind executes the query and inserts the // Bind executes the query and inserts the
// result into the passed in object pointer // result into the passed in object pointer
func (q *Query) Bind(obj interface{}) error { func (q *Query) Bind(obj interface{}) error {
typ := reflect.TypeOf(obj) return nil
/*typ := reflect.TypeOf(obj)
kind := typ.Kind() kind := typ.Kind()
if kind != reflect.Ptr { if kind != reflect.Ptr {
@ -40,7 +41,7 @@ func (q *Query) Bind(obj interface{}) error {
return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String()) return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String())
} }
return nil return nil*/
} }
// BindP executes the query and inserts the // BindP executes the query and inserts the