Change join structure to truly support any join

- Also normalize the "," vs ", " in query generation
This commit is contained in:
Aaron L 2016-08-07 13:37:51 -07:00
parent 162746526c
commit 6596868cb8
7 changed files with 128 additions and 99 deletions

View file

@ -1 +1 @@
SELECT count(*) as ab, thing as bd, "stuff" FROM "t";
SELECT count(*) as ab, thing as bd,"stuff" FROM "t";

View file

@ -1 +1 @@
SELECT count(*) as ab, thing as bd, "stuff" FROM "a","b";
SELECT count(*) as ab, thing as bd,"stuff" FROM "a","b";

View file

@ -48,9 +48,9 @@ func Offset(offset int) QueryMod {
}
// InnerJoin on another table
func InnerJoin(stmt string, args ...interface{}) QueryMod {
func InnerJoin(clause string, args ...interface{}) QueryMod {
return func(q *boil.Query) {
boil.AppendInnerJoin(q, stmt, args...)
boil.AppendInnerJoin(q, boil.JoinInner, clause, args...)
}
}

View file

@ -5,6 +5,17 @@ import (
"fmt"
)
// joinKind is the type of join
type joinKind int
// Join type constants
const (
JoinInner joinKind = iota
JoinOuterLeft
JoinOuterRight
JoinNatural
)
// Query holds the state for the built up query
type Query struct {
executor Executor
@ -14,7 +25,7 @@ type Query struct {
selectCols []string
modFunction string
from []string
innerJoins []join
joins []join
where []where
groupBy []string
orderBy []string
@ -35,8 +46,9 @@ type plainSQL struct {
}
type join struct {
on string
args []interface{}
kind joinKind
clause string
args []interface{}
}
// ExecStatement executes a query that does not need a row returned
@ -130,13 +142,13 @@ func SetFrom(q *Query, from ...string) {
}
// AppendInnerJoin on the query.
func AppendInnerJoin(q *Query, on string, args ...interface{}) {
q.innerJoins = append(q.innerJoins, join{on: on, args: args})
func AppendInnerJoin(q *Query, clause string, args ...interface{}) {
q.joins = append(q.joins, join{clause: clause, kind: JoinInner, args: args})
}
// SetInnerJoin on the query.
func SetInnerJoin(q *Query, on string, args ...interface{}) {
q.innerJoins = append([]join(nil), join{on: on, args: args})
func SetInnerJoin(q *Query, clause string, args ...interface{}) {
q.joins = append([]join(nil), join{clause: clause, kind: JoinInner, args: args})
}
// AppendWhere on the query.

View file

@ -44,10 +44,16 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
}
hasSelectCols := len(q.selectCols) != 0
if len(q.innerJoins) != 0 && hasSelectCols && !hasModFunc {
writeComplexSelect(q, buf)
hasJoins := len(q.joins) != 0
if hasJoins && !hasModFunc {
selectColsWithAs := writeAsStatements(q)
// Don't identQuoteSlice - writeAsStatements does this
buf.WriteString(strings.Join(selectColsWithAs, `,`))
} else if hasSelectCols {
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `, `))
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `,`))
} else if hasJoins {
selectColsWithStars := writeStars(q)
buf.WriteString(strings.Join(selectColsWithStars, `,`))
} else {
buf.WriteByte('*')
}
@ -56,14 +62,21 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
buf.WriteString(")")
}
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ","))
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), `,`))
for _, j := range q.joins {
if j.kind != JoinInner {
panic("only inner joins are supported")
}
fmt.Fprintf(buf, " INNER JOIN %s", j.clause)
}
where, args := whereClause(q)
buf.WriteString(where)
if len(q.orderBy) != 0 {
buf.WriteString(" ORDER BY ")
buf.WriteString(strings.Join(q.orderBy, ","))
buf.WriteString(strings.Join(q.orderBy, `,`))
}
if q.limit != 0 {
@ -147,13 +160,9 @@ func identifierMapping(q *Query) map[string]string {
parseIdentifierClause(tokens, setID)
}
for _, join := range q.innerJoins {
tokens := strings.Split(join.on, " ")
discard := 0
for rgxJoinIdentifiers.MatchString(tokens[discard]) {
discard++
}
parseIdentifierClause(tokens[discard:], setID)
for _, join := range q.joins {
tokens := strings.Split(join.clause, " ")
parseIdentifierClause(tokens, setID)
}
return ids

View file

@ -1,6 +1,69 @@
package boil
import "testing"
import (
"bytes"
"flag"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
)
var writeGoldenFiles = flag.Bool(
"test.golden",
false,
"Write golden files.",
)
func TestBuildQuery(t *testing.T) {
t.Parallel()
tests := []struct {
q *Query
args []interface{}
}{
{&Query{from: []string{"t"}}, nil},
{&Query{from: []string{"q"}, limit: 5, offset: 6}, nil},
{&Query{from: []string{"q"}, orderBy: []string{"a ASC", "b DESC"}}, nil},
{&Query{from: []string{"t"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
{&Query{from: []string{"a", "b"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
{&Query{
selectCols: []string{"a.happy", "r.fun", "q"},
from: []string{"happiness as a"},
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
}, nil},
}
for i, test := range tests {
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
out, args := buildQuery(test.q)
if *writeGoldenFiles {
err := ioutil.WriteFile(filename, []byte(out), 0664)
if err != nil {
t.Fatalf("Failed to write golden file %s: %s\n", filename, err)
}
t.Logf("wrote golden file: %s\n", filename)
continue
}
byt, err := ioutil.ReadFile(filename)
if err != nil {
t.Fatalf("Failed to read golden file %q: %v", filename, err)
}
if string(bytes.TrimSpace(byt)) != out {
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, byt, out)
}
if !reflect.DeepEqual(args, test.args) {
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, spew.Sdump(test.args), spew.Sdump(args))
}
}
}
func TestIdentifierMapping(t *testing.T) {
t.Parallel()
@ -26,19 +89,19 @@ func TestIdentifierMapping(t *testing.T) {
Out: map[string]string{"b": "a", "d": "c"},
},
{
In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}},
In: Query{joins: []join{{kind: JoinInner, clause: `a on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}},
In: Query{joins: []join{{kind: JoinNatural, clause: `"a" on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}},
In: Query{joins: []join{{kind: JoinNatural, clause: `a as b on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
{
In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}},
In: Query{joins: []join{{kind: JoinOuterRight, clause: `"a" as "b" on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
}

View file

@ -1,66 +1,11 @@
package boil
import (
"bytes"
"database/sql"
"flag"
"fmt"
"io/ioutil"
"path/filepath"
"reflect"
"testing"
"github.com/davecgh/go-spew/spew"
)
var writeGoldenFiles = flag.Bool(
"test.golden",
false,
"Write golden files.",
)
func TestBuildQuery(t *testing.T) {
t.Parallel()
tests := []struct {
q *Query
args []interface{}
}{
{&Query{from: []string{"t"}}, nil},
{&Query{from: []string{"q"}, limit: 5, offset: 6}, nil},
{&Query{from: []string{"q"}, orderBy: []string{"a ASC", "b DESC"}}, nil},
{&Query{from: []string{"t"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
{&Query{from: []string{"a", "b"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
}
for i, test := range tests {
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
out, args := buildQuery(test.q)
if *writeGoldenFiles {
err := ioutil.WriteFile(filename, []byte(out), 0664)
if err != nil {
t.Fatalf("Failed to write golden file %s: %s\n", filename, err)
}
t.Logf("wrote golden file: %s\n", filename)
continue
}
byt, err := ioutil.ReadFile(filename)
if err != nil {
t.Fatalf("Failed to read golden file %q: %v", filename, err)
}
if string(bytes.TrimSpace(byt)) != out {
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, byt, out)
}
if !reflect.DeepEqual(args, test.args) {
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, spew.Sdump(test.args), spew.Sdump(args))
}
}
}
func TestSetLastWhereAsOr(t *testing.T) {
t.Parallel()
q := &Query{}
@ -301,34 +246,34 @@ func TestInnerJoin(t *testing.T) {
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
if len(q.innerJoins) != 2 {
t.Errorf("Expected len 1, got %d", len(q.innerJoins))
if len(q.joins) != 2 {
t.Errorf("Expected len 1, got %d", len(q.joins))
}
if q.innerJoins[0].on != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
}
if q.innerJoins[1].on != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
if q.joins[1].clause != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
}
if len(q.innerJoins[0].args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.innerJoins[0].args))
if len(q.joins[0].args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.joins[0].args))
}
if len(q.innerJoins[1].args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.innerJoins[1].args))
if len(q.joins[1].args) != 2 {
t.Errorf("Expected len 2, got %d", len(q.joins[1].args))
}
if q.innerJoins[0].args[0] != 2 && q.innerJoins[0].args[1] != 5 {
t.Errorf("Invalid args values, got %#v", q.innerJoins[0].args)
if q.joins[0].args[0] != 2 && q.joins[0].args[1] != 5 {
t.Errorf("Invalid args values, got %#v", q.joins[0].args)
}
SetInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
if len(q.innerJoins) != 1 {
t.Errorf("Expected len 1, got %d", len(q.innerJoins))
if len(q.joins) != 1 {
t.Errorf("Expected len 1, got %d", len(q.joins))
}
if q.innerJoins[0].on != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
}
}