Change join structure to truly support any join
- Also normalize the "," vs ", " in query generation
This commit is contained in:
parent
162746526c
commit
6596868cb8
7 changed files with 128 additions and 99 deletions
|
@ -1 +1 @@
|
||||||
SELECT count(*) as ab, thing as bd, "stuff" FROM "t";
|
SELECT count(*) as ab, thing as bd,"stuff" FROM "t";
|
|
@ -1 +1 @@
|
||||||
SELECT count(*) as ab, thing as bd, "stuff" FROM "a","b";
|
SELECT count(*) as ab, thing as bd,"stuff" FROM "a","b";
|
|
@ -48,9 +48,9 @@ func Offset(offset int) QueryMod {
|
||||||
}
|
}
|
||||||
|
|
||||||
// InnerJoin on another table
|
// InnerJoin on another table
|
||||||
func InnerJoin(stmt string, args ...interface{}) QueryMod {
|
func InnerJoin(clause string, args ...interface{}) QueryMod {
|
||||||
return func(q *boil.Query) {
|
return func(q *boil.Query) {
|
||||||
boil.AppendInnerJoin(q, stmt, args...)
|
boil.AppendInnerJoin(q, boil.JoinInner, clause, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,17 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// joinKind is the type of join
|
||||||
|
type joinKind int
|
||||||
|
|
||||||
|
// Join type constants
|
||||||
|
const (
|
||||||
|
JoinInner joinKind = iota
|
||||||
|
JoinOuterLeft
|
||||||
|
JoinOuterRight
|
||||||
|
JoinNatural
|
||||||
|
)
|
||||||
|
|
||||||
// Query holds the state for the built up query
|
// Query holds the state for the built up query
|
||||||
type Query struct {
|
type Query struct {
|
||||||
executor Executor
|
executor Executor
|
||||||
|
@ -14,7 +25,7 @@ type Query struct {
|
||||||
selectCols []string
|
selectCols []string
|
||||||
modFunction string
|
modFunction string
|
||||||
from []string
|
from []string
|
||||||
innerJoins []join
|
joins []join
|
||||||
where []where
|
where []where
|
||||||
groupBy []string
|
groupBy []string
|
||||||
orderBy []string
|
orderBy []string
|
||||||
|
@ -35,8 +46,9 @@ type plainSQL struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type join struct {
|
type join struct {
|
||||||
on string
|
kind joinKind
|
||||||
args []interface{}
|
clause string
|
||||||
|
args []interface{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ExecStatement executes a query that does not need a row returned
|
// ExecStatement executes a query that does not need a row returned
|
||||||
|
@ -130,13 +142,13 @@ func SetFrom(q *Query, from ...string) {
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendInnerJoin on the query.
|
// AppendInnerJoin on the query.
|
||||||
func AppendInnerJoin(q *Query, on string, args ...interface{}) {
|
func AppendInnerJoin(q *Query, clause string, args ...interface{}) {
|
||||||
q.innerJoins = append(q.innerJoins, join{on: on, args: args})
|
q.joins = append(q.joins, join{clause: clause, kind: JoinInner, args: args})
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetInnerJoin on the query.
|
// SetInnerJoin on the query.
|
||||||
func SetInnerJoin(q *Query, on string, args ...interface{}) {
|
func SetInnerJoin(q *Query, clause string, args ...interface{}) {
|
||||||
q.innerJoins = append([]join(nil), join{on: on, args: args})
|
q.joins = append([]join(nil), join{clause: clause, kind: JoinInner, args: args})
|
||||||
}
|
}
|
||||||
|
|
||||||
// AppendWhere on the query.
|
// AppendWhere on the query.
|
||||||
|
|
|
@ -44,10 +44,16 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
hasSelectCols := len(q.selectCols) != 0
|
hasSelectCols := len(q.selectCols) != 0
|
||||||
if len(q.innerJoins) != 0 && hasSelectCols && !hasModFunc {
|
hasJoins := len(q.joins) != 0
|
||||||
writeComplexSelect(q, buf)
|
if hasJoins && !hasModFunc {
|
||||||
|
selectColsWithAs := writeAsStatements(q)
|
||||||
|
// Don't identQuoteSlice - writeAsStatements does this
|
||||||
|
buf.WriteString(strings.Join(selectColsWithAs, `,`))
|
||||||
} else if hasSelectCols {
|
} else if hasSelectCols {
|
||||||
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `, `))
|
buf.WriteString(strings.Join(strmangle.IdentQuoteSlice(q.selectCols), `,`))
|
||||||
|
} else if hasJoins {
|
||||||
|
selectColsWithStars := writeStars(q)
|
||||||
|
buf.WriteString(strings.Join(selectColsWithStars, `,`))
|
||||||
} else {
|
} else {
|
||||||
buf.WriteByte('*')
|
buf.WriteByte('*')
|
||||||
}
|
}
|
||||||
|
@ -56,14 +62,21 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||||
buf.WriteString(")")
|
buf.WriteString(")")
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), ","))
|
fmt.Fprintf(buf, " FROM %s", strings.Join(strmangle.IdentQuoteSlice(q.from), `,`))
|
||||||
|
|
||||||
|
for _, j := range q.joins {
|
||||||
|
if j.kind != JoinInner {
|
||||||
|
panic("only inner joins are supported")
|
||||||
|
}
|
||||||
|
fmt.Fprintf(buf, " INNER JOIN %s", j.clause)
|
||||||
|
}
|
||||||
|
|
||||||
where, args := whereClause(q)
|
where, args := whereClause(q)
|
||||||
buf.WriteString(where)
|
buf.WriteString(where)
|
||||||
|
|
||||||
if len(q.orderBy) != 0 {
|
if len(q.orderBy) != 0 {
|
||||||
buf.WriteString(" ORDER BY ")
|
buf.WriteString(" ORDER BY ")
|
||||||
buf.WriteString(strings.Join(q.orderBy, ","))
|
buf.WriteString(strings.Join(q.orderBy, `,`))
|
||||||
}
|
}
|
||||||
|
|
||||||
if q.limit != 0 {
|
if q.limit != 0 {
|
||||||
|
@ -147,13 +160,9 @@ func identifierMapping(q *Query) map[string]string {
|
||||||
parseIdentifierClause(tokens, setID)
|
parseIdentifierClause(tokens, setID)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, join := range q.innerJoins {
|
for _, join := range q.joins {
|
||||||
tokens := strings.Split(join.on, " ")
|
tokens := strings.Split(join.clause, " ")
|
||||||
discard := 0
|
parseIdentifierClause(tokens, setID)
|
||||||
for rgxJoinIdentifiers.MatchString(tokens[discard]) {
|
|
||||||
discard++
|
|
||||||
}
|
|
||||||
parseIdentifierClause(tokens[discard:], setID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return ids
|
return ids
|
||||||
|
|
|
@ -1,6 +1,69 @@
|
||||||
package boil
|
package boil
|
||||||
|
|
||||||
import "testing"
|
import (
|
||||||
|
"bytes"
|
||||||
|
"flag"
|
||||||
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
|
"path/filepath"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/davecgh/go-spew/spew"
|
||||||
|
)
|
||||||
|
|
||||||
|
var writeGoldenFiles = flag.Bool(
|
||||||
|
"test.golden",
|
||||||
|
false,
|
||||||
|
"Write golden files.",
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildQuery(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
q *Query
|
||||||
|
args []interface{}
|
||||||
|
}{
|
||||||
|
{&Query{from: []string{"t"}}, nil},
|
||||||
|
{&Query{from: []string{"q"}, limit: 5, offset: 6}, nil},
|
||||||
|
{&Query{from: []string{"q"}, orderBy: []string{"a ASC", "b DESC"}}, nil},
|
||||||
|
{&Query{from: []string{"t"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
|
||||||
|
{&Query{from: []string{"a", "b"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
|
||||||
|
{&Query{
|
||||||
|
selectCols: []string{"a.happy", "r.fun", "q"},
|
||||||
|
from: []string{"happiness as a"},
|
||||||
|
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
|
||||||
|
}, nil},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, test := range tests {
|
||||||
|
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
|
||||||
|
out, args := buildQuery(test.q)
|
||||||
|
|
||||||
|
if *writeGoldenFiles {
|
||||||
|
err := ioutil.WriteFile(filename, []byte(out), 0664)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to write golden file %s: %s\n", filename, err)
|
||||||
|
}
|
||||||
|
t.Logf("wrote golden file: %s\n", filename)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
byt, err := ioutil.ReadFile(filename)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to read golden file %q: %v", filename, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(bytes.TrimSpace(byt)) != out {
|
||||||
|
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, byt, out)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !reflect.DeepEqual(args, test.args) {
|
||||||
|
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, spew.Sdump(test.args), spew.Sdump(args))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestIdentifierMapping(t *testing.T) {
|
func TestIdentifierMapping(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
@ -26,19 +89,19 @@ func TestIdentifierMapping(t *testing.T) {
|
||||||
Out: map[string]string{"b": "a", "d": "c"},
|
Out: map[string]string{"b": "a", "d": "c"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}},
|
In: Query{joins: []join{{kind: JoinInner, clause: `a on stuff = there`}}},
|
||||||
Out: map[string]string{"a": "a"},
|
Out: map[string]string{"a": "a"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}},
|
In: Query{joins: []join{{kind: JoinNatural, clause: `"a" on stuff = there`}}},
|
||||||
Out: map[string]string{"a": "a"},
|
Out: map[string]string{"a": "a"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}},
|
In: Query{joins: []join{{kind: JoinNatural, clause: `a as b on stuff = there`}}},
|
||||||
Out: map[string]string{"b": "a"},
|
Out: map[string]string{"b": "a"},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}},
|
In: Query{joins: []join{{kind: JoinOuterRight, clause: `"a" as "b" on stuff = there`}}},
|
||||||
Out: map[string]string{"b": "a"},
|
Out: map[string]string{"b": "a"},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,66 +1,11 @@
|
||||||
package boil
|
package boil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"path/filepath"
|
|
||||||
"reflect"
|
"reflect"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var writeGoldenFiles = flag.Bool(
|
|
||||||
"test.golden",
|
|
||||||
false,
|
|
||||||
"Write golden files.",
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBuildQuery(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
q *Query
|
|
||||||
args []interface{}
|
|
||||||
}{
|
|
||||||
{&Query{from: []string{"t"}}, nil},
|
|
||||||
{&Query{from: []string{"q"}, limit: 5, offset: 6}, nil},
|
|
||||||
{&Query{from: []string{"q"}, orderBy: []string{"a ASC", "b DESC"}}, nil},
|
|
||||||
{&Query{from: []string{"t"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
|
|
||||||
{&Query{from: []string{"a", "b"}, selectCols: []string{"count(*) as ab, thing as bd", `"stuff"`}}, nil},
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, test := range tests {
|
|
||||||
filename := filepath.Join("_fixtures", fmt.Sprintf("%02d.sql", i))
|
|
||||||
out, args := buildQuery(test.q)
|
|
||||||
|
|
||||||
if *writeGoldenFiles {
|
|
||||||
err := ioutil.WriteFile(filename, []byte(out), 0664)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to write golden file %s: %s\n", filename, err)
|
|
||||||
}
|
|
||||||
t.Logf("wrote golden file: %s\n", filename)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
byt, err := ioutil.ReadFile(filename)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Failed to read golden file %q: %v", filename, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if string(bytes.TrimSpace(byt)) != out {
|
|
||||||
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, byt, out)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !reflect.DeepEqual(args, test.args) {
|
|
||||||
t.Errorf("[%02d] Test failed:\nWant:\n%s\nGot:\n%s", i, spew.Sdump(test.args), spew.Sdump(args))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetLastWhereAsOr(t *testing.T) {
|
func TestSetLastWhereAsOr(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
q := &Query{}
|
q := &Query{}
|
||||||
|
@ -301,34 +246,34 @@ func TestInnerJoin(t *testing.T) {
|
||||||
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
||||||
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
AppendInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
||||||
|
|
||||||
if len(q.innerJoins) != 2 {
|
if len(q.joins) != 2 {
|
||||||
t.Errorf("Expected len 1, got %d", len(q.innerJoins))
|
t.Errorf("Expected len 1, got %d", len(q.joins))
|
||||||
}
|
}
|
||||||
|
|
||||||
if q.innerJoins[0].on != "thing=$1 AND stuff=$2" {
|
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
|
||||||
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
|
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
|
||||||
}
|
}
|
||||||
if q.innerJoins[1].on != "thing=$1 AND stuff=$2" {
|
if q.joins[1].clause != "thing=$1 AND stuff=$2" {
|
||||||
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
|
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(q.innerJoins[0].args) != 2 {
|
if len(q.joins[0].args) != 2 {
|
||||||
t.Errorf("Expected len 2, got %d", len(q.innerJoins[0].args))
|
t.Errorf("Expected len 2, got %d", len(q.joins[0].args))
|
||||||
}
|
}
|
||||||
if len(q.innerJoins[1].args) != 2 {
|
if len(q.joins[1].args) != 2 {
|
||||||
t.Errorf("Expected len 2, got %d", len(q.innerJoins[1].args))
|
t.Errorf("Expected len 2, got %d", len(q.joins[1].args))
|
||||||
}
|
}
|
||||||
|
|
||||||
if q.innerJoins[0].args[0] != 2 && q.innerJoins[0].args[1] != 5 {
|
if q.joins[0].args[0] != 2 && q.joins[0].args[1] != 5 {
|
||||||
t.Errorf("Invalid args values, got %#v", q.innerJoins[0].args)
|
t.Errorf("Invalid args values, got %#v", q.joins[0].args)
|
||||||
}
|
}
|
||||||
|
|
||||||
SetInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
SetInnerJoin(q, "thing=$1 AND stuff=$2", 2, 5)
|
||||||
if len(q.innerJoins) != 1 {
|
if len(q.joins) != 1 {
|
||||||
t.Errorf("Expected len 1, got %d", len(q.innerJoins))
|
t.Errorf("Expected len 1, got %d", len(q.joins))
|
||||||
}
|
}
|
||||||
|
|
||||||
if q.innerJoins[0].on != "thing=$1 AND stuff=$2" {
|
if q.joins[0].clause != "thing=$1 AND stuff=$2" {
|
||||||
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
|
t.Errorf("Got invalid innerJoin on string: %#v", q.joins)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue