From 5541b4dce9dc3767f16389dcc6415558e4261668 Mon Sep 17 00:00:00 2001 From: Aaron L Date: Sat, 6 Aug 2016 14:42:22 -0700 Subject: [PATCH] Reorganize things before tearing apart --- boil/query.go | 220 +++--------------------------------- boil/query_builders.go | 196 ++++++++++++++++++++++++++++++++ boil/query_builders_test.go | 59 ++++++++++ boil/query_test.go | 56 --------- boil/reflect.go | 5 +- 5 files changed, 272 insertions(+), 264 deletions(-) create mode 100644 boil/query_builders.go create mode 100644 boil/query_builders_test.go diff --git a/boil/query.go b/boil/query.go index febb61a..9c865dd 100644 --- a/boil/query.go +++ b/boil/query.go @@ -1,31 +1,10 @@ package boil import ( - "bytes" "database/sql" "fmt" - "regexp" - "strings" - - "github.com/nullbio/sqlboiler/strmangle" ) -type where struct { - clause string - orSeperator bool - args []interface{} -} - -type plainSQL struct { - sql string - args []interface{} -} - -type join struct { - on string - args []interface{} -} - // Query holds the state for the built up query type Query struct { executor Executor @@ -44,87 +23,24 @@ type Query struct { offset int } -func buildQuery(q *Query) (string, []interface{}) { - var buf *bytes.Buffer - var args []interface{} - - switch { - case len(q.plainSQL.sql) != 0: - return q.plainSQL.sql, q.plainSQL.args - case q.delete: - buf, args = buildDeleteQuery(q) - case len(q.update) > 0: - buf, args = buildUpdateQuery(q) - default: - buf, args = buildSelectQuery(q) - } - - return buf.String(), args +type where struct { + clause string + orSeperator bool + args []interface{} } -func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { - buf := &bytes.Buffer{} - - buf.WriteString("SELECT ") - - if q.count { - buf.WriteString("COUNT(") - } - if len(q.selectCols) > 0 { - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `, `)) - } else { - buf.WriteByte('*') - } - // close sql COUNT function - if q.count { - buf.WriteString(")") - } - - buf.WriteString(" FROM ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ",")) - - where, args := whereClause(q) - buf.WriteString(where) - - if len(q.orderBy) != 0 { - buf.WriteString(" ORDER BY ") - buf.WriteString(strings.Join(q.orderBy, ",")) - } - - if q.limit != 0 { - fmt.Fprintf(buf, " LIMIT %d", q.limit) - } - if q.offset != 0 { - fmt.Fprintf(buf, " OFFSET %d", q.offset) - } - - buf.WriteByte(';') - return buf, args +type plainSQL struct { + sql string + args []interface{} } -func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { - buf := &bytes.Buffer{} - - buf.WriteString("DELETE FROM ") - buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ",")) - - where, args := whereClause(q) - buf.WriteString(where) - - buf.WriteByte(';') - - return buf, args +type join struct { + on string + args []interface{} } -func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { - buf := &bytes.Buffer{} - - buf.WriteByte(';') - return buf, nil -} - -// ExecQuery executes a query that does not need a row returned -func ExecQuery(q *Query) (sql.Result, error) { +// ExecStatement executes a query that does not need a row returned +func ExecStatement(q *Query) (sql.Result, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -132,17 +48,8 @@ func ExecQuery(q *Query) (sql.Result, error) { return q.executor.Exec(qs, args...) } -// ExecQueryOne executes the query for the One finisher and returns a row -func ExecQueryOne(q *Query) *sql.Row { - qs, args := buildQuery(q) - if DebugMode { - fmt.Fprintln(DebugWriter, qs) - } - return q.executor.QueryRow(qs, args...) -} - -// ExecQueryAll executes the query for the All finisher and returns multiple rows -func ExecQueryAll(q *Query) (*sql.Rows, error) { +// ExecQuery executes the query for the All finisher and returns multiple rows +func ExecQuery(q *Query) (*sql.Rows, error) { qs, args := buildQuery(q) if DebugMode { fmt.Fprintln(DebugWriter, qs) @@ -266,102 +173,3 @@ func SetLimit(q *Query, limit int) { func SetOffset(q *Query, offset int) { q.offset = offset } - -func whereClause(q *Query) (string, []interface{}) { - if len(q.where) == 0 { - return "", nil - } - - buf := &bytes.Buffer{} - var args []interface{} - - buf.WriteString(" WHERE ") - for i := 0; i < len(q.where); i++ { - buf.WriteString(fmt.Sprintf("%s", q.where[i].clause)) - args = append(args, q.where[i].args...) - if i >= len(q.where)-1 { - continue - } - if q.where[i].orSeperator { - buf.WriteString(" OR ") - } else { - buf.WriteString(" AND ") - } - } - - return buf.String(), args -} - -var ( - rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`) - rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`) -) - -// identifierMapping creates a map of all identifiers to potential model names -func identifierMapping(q *Query) map[string]string { - var ids map[string]string - setID := func(alias, name string) { - if ids == nil { - ids = make(map[string]string) - } - ids[alias] = name - } - - for _, from := range q.from { - tokens := strings.Split(from, " ") - parseIdentifierClause(tokens, setID) - } - - for _, join := range q.innerJoins { - tokens := strings.Split(join.on, " ") - discard := 0 - for rgxJoinIdentifiers.MatchString(tokens[discard]) { - discard++ - } - parseIdentifierClause(tokens[discard:], setID) - } - - return ids -} - -// parseBits takes a set of tokens and looks for something of the form: -// a b -// a as b -// where 'a' and 'b' are valid SQL identifiers -// It only evaluates the first 3 tokens (anything past that is superfluous) -// It stops parsing when it finds "on" or an invalid identifier -func parseIdentifierClause(tokens []string, setID func(string, string)) { - var name, alias string - sawIdent, sawAs := false, false - - if len(tokens) > 3 { - tokens = tokens[:3] - } - - for _, tok := range tokens { - if t := strings.ToLower(tok); sawIdent && t == "as" { - sawAs = true - continue - } else if sawIdent && t == "on" { - break - } - - if !rgxIdentifier.MatchString(tok) { - break - } - - if sawIdent || sawAs { - alias = strings.Trim(tok, `"`) - break - } - - name = strings.Trim(tok, `"`) - sawIdent = true - } - - if len(alias) > 0 { - setID(alias, name) - } else { - setID(name, name) - } -} diff --git a/boil/query_builders.go b/boil/query_builders.go new file mode 100644 index 0000000..6a1b023 --- /dev/null +++ b/boil/query_builders.go @@ -0,0 +1,196 @@ +package boil + +import ( + "bytes" + "fmt" + "regexp" + "strings" + + "github.com/nullbio/sqlboiler/strmangle" +) + +var ( + rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`) + rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`) +) + +func buildQuery(q *Query) (string, []interface{}) { + var buf *bytes.Buffer + var args []interface{} + + switch { + case len(q.plainSQL.sql) != 0: + return q.plainSQL.sql, q.plainSQL.args + case q.delete: + buf, args = buildDeleteQuery(q) + case len(q.update) > 0: + buf, args = buildUpdateQuery(q) + default: + buf, args = buildSelectQuery(q) + } + + return buf.String(), args +} + +func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { + buf := &bytes.Buffer{} + + buf.WriteString("SELECT ") + + if q.count { + buf.WriteString("COUNT(") + } + + if len(q.innerJoins) > 0 && !q.count { + writeComplexSelect(q, buf) + } else { + if len(q.selectCols) > 0 { + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `, `)) + } else { + buf.WriteByte('*') + } + } + + // close sql COUNT function + if q.count { + buf.WriteString(")") + } + + fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ",")) + + where, args := whereClause(q) + buf.WriteString(where) + + if len(q.orderBy) != 0 { + buf.WriteString(" ORDER BY ") + buf.WriteString(strings.Join(q.orderBy, ",")) + } + + if q.limit != 0 { + fmt.Fprintf(buf, " LIMIT %d", q.limit) + } + if q.offset != 0 { + fmt.Fprintf(buf, " OFFSET %d", q.offset) + } + + buf.WriteByte(';') + return buf, args +} + +func writeComplexSelect(q *Query, buf *bytes.Buffer) { +} + +func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { + buf := &bytes.Buffer{} + + buf.WriteString("DELETE FROM ") + buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.from), ",")) + + where, args := whereClause(q) + buf.WriteString(where) + + buf.WriteByte(';') + + return buf, args +} + +func buildUpdateQuery(q *Query) (*bytes.Buffer, []interface{}) { + buf := &bytes.Buffer{} + + buf.WriteByte(';') + return buf, nil +} + +func whereClause(q *Query) (string, []interface{}) { + if len(q.where) == 0 { + return "", nil + } + + buf := &bytes.Buffer{} + var args []interface{} + + buf.WriteString(" WHERE ") + for i := 0; i < len(q.where); i++ { + buf.WriteString(fmt.Sprintf("%s", q.where[i].clause)) + args = append(args, q.where[i].args...) + if i >= len(q.where)-1 { + continue + } + if q.where[i].orSeperator { + buf.WriteString(" OR ") + } else { + buf.WriteString(" AND ") + } + } + + return buf.String(), args +} + +// identifierMapping creates a map of all identifiers to potential model names +func identifierMapping(q *Query) map[string]string { + var ids map[string]string + setID := func(alias, name string) { + if ids == nil { + ids = make(map[string]string) + } + ids[alias] = name + } + + for _, from := range q.from { + tokens := strings.Split(from, " ") + parseIdentifierClause(tokens, setID) + } + + for _, join := range q.innerJoins { + tokens := strings.Split(join.on, " ") + discard := 0 + for rgxJoinIdentifiers.MatchString(tokens[discard]) { + discard++ + } + parseIdentifierClause(tokens[discard:], setID) + } + + return ids +} + +// parseBits takes a set of tokens and looks for something of the form: +// a b +// a as b +// where 'a' and 'b' are valid SQL identifiers +// It only evaluates the first 3 tokens (anything past that is superfluous) +// It stops parsing when it finds "on" or an invalid identifier +func parseIdentifierClause(tokens []string, setID func(string, string)) { + var name, alias string + sawIdent, sawAs := false, false + + if len(tokens) > 3 { + tokens = tokens[:3] + } + + for _, tok := range tokens { + if t := strings.ToLower(tok); sawIdent && t == "as" { + sawAs = true + continue + } else if sawIdent && t == "on" { + break + } + + if !rgxIdentifier.MatchString(tok) { + break + } + + if sawIdent || sawAs { + alias = strings.Trim(tok, `"`) + break + } + + name = strings.Trim(tok, `"`) + sawIdent = true + } + + if len(alias) > 0 { + setID(alias, name) + } else { + setID(name, name) + } +} diff --git a/boil/query_builders_test.go b/boil/query_builders_test.go new file mode 100644 index 0000000..27b4aaf --- /dev/null +++ b/boil/query_builders_test.go @@ -0,0 +1,59 @@ +package boil + +import "testing" + +func TestIdentifierMapping(t *testing.T) { + t.Parallel() + + tests := []struct { + In Query + Out map[string]string + }{ + { + In: Query{from: []string{`a`}}, + Out: map[string]string{"a": "a"}, + }, + { + In: Query{from: []string{`"a"`, `b`}}, + Out: map[string]string{"a": "a", "b": "b"}, + }, + { + In: Query{from: []string{`a as b`}}, + Out: map[string]string{"b": "a"}, + }, + { + In: Query{from: []string{`a AS "b"`, `"c" as d`}}, + Out: map[string]string{"b": "a", "d": "c"}, + }, + { + In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}}, + Out: map[string]string{"a": "a"}, + }, + { + In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}}, + Out: map[string]string{"a": "a"}, + }, + { + In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}}, + Out: map[string]string{"b": "a"}, + }, + { + In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}}, + Out: map[string]string{"b": "a"}, + }, + } + + for i, test := range tests { + m := identifierMapping(&test.In) + + for k, v := range test.Out { + val, ok := m[k] + if !ok { + t.Errorf("%d) want: %s = %s, but was missing", i, k, v) + } + if val != v { + t.Errorf("%d) want: %s = %s, got: %s", i, k, v, val) + } + } + } +} diff --git a/boil/query_test.go b/boil/query_test.go index 79d317c..73e11cc 100644 --- a/boil/query_test.go +++ b/boil/query_test.go @@ -332,59 +332,3 @@ func TestInnerJoin(t *testing.T) { t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins) } } - -func TestIdentifierMapping(t *testing.T) { - t.Parallel() - - tests := []struct { - In Query - Out map[string]string - }{ - { - In: Query{from: []string{`a`}}, - Out: map[string]string{"a": "a"}, - }, - { - In: Query{from: []string{`"a"`, `b`}}, - Out: map[string]string{"a": "a", "b": "b"}, - }, - { - In: Query{from: []string{`a as b`}}, - Out: map[string]string{"b": "a"}, - }, - { - In: Query{from: []string{`a AS "b"`, `"c" as d`}}, - Out: map[string]string{"b": "a", "d": "c"}, - }, - { - In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}}, - Out: map[string]string{"a": "a"}, - }, - { - In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}}, - Out: map[string]string{"a": "a"}, - }, - { - In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}}, - Out: map[string]string{"b": "a"}, - }, - { - In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}}, - Out: map[string]string{"b": "a"}, - }, - } - - for i, test := range tests { - m := identifierMapping(&test.In) - - for k, v := range test.Out { - val, ok := m[k] - if !ok { - t.Errorf("%d) want: %s = %s, but was missing", i, k, v) - } - if val != v { - t.Errorf("%d) want: %s = %s, got: %s", i, k, v, val) - } - } - } -} diff --git a/boil/reflect.go b/boil/reflect.go index 25a633e..662c5a6 100644 --- a/boil/reflect.go +++ b/boil/reflect.go @@ -11,7 +11,8 @@ import ( // Bind executes the query and inserts the // result into the passed in object pointer func (q *Query) Bind(obj interface{}) error { - typ := reflect.TypeOf(obj) + return nil + /*typ := reflect.TypeOf(obj) kind := typ.Kind() if kind != reflect.Ptr { @@ -40,7 +41,7 @@ func (q *Query) Bind(obj interface{}) error { return fmt.Errorf("Bind given a pointer to a non-slice or non-struct: %s", typ.String()) } - return nil + return nil*/ } // BindP executes the query and inserts the