Add test for a query with writeStars
This commit is contained in:
parent
6596868cb8
commit
1ba2333658
4 changed files with 137 additions and 19 deletions
1
boil/_fixtures/05.sql
Normal file
1
boil/_fixtures/05.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "a"."happy" as "a.happy","r"."fun" as "r.fun","q" FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id;
|
1
boil/_fixtures/06.sql
Normal file
1
boil/_fixtures/06.sql
Normal file
|
@ -0,0 +1 @@
|
|||
SELECT "a".* FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id;
|
|
@ -11,7 +11,6 @@ import (
|
|||
|
||||
var (
|
||||
rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
|
||||
rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`)
|
||||
)
|
||||
|
||||
func buildQuery(q *Query) (string, []interface{}) {
|
||||
|
@ -45,7 +44,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|||
|
||||
hasSelectCols := len(q.selectCols) != 0
|
||||
hasJoins := len(q.joins) != 0
|
||||
if hasJoins && !hasModFunc {
|
||||
if hasSelectCols && hasJoins && !hasModFunc {
|
||||
selectColsWithAs := writeAsStatements(q)
|
||||
// Don't identQuoteSlice - writeAsStatements does this
|
||||
buf.WriteString(strings.Join(selectColsWithAs, `,`))
|
||||
|
@ -90,13 +89,51 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
|||
return buf, args
|
||||
}
|
||||
|
||||
func writeComplexSelect(q *Query, buf *bytes.Buffer) {
|
||||
func writeStars(q *Query) []string {
|
||||
cols := make([]string, 0, len(q.from))
|
||||
for _, f := range q.from {
|
||||
toks := strings.Split(f, " ")
|
||||
if len(toks) == 1 {
|
||||
cols = append(cols, fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0])))
|
||||
continue
|
||||
}
|
||||
|
||||
alias, name, ok := parseFromClause(toks)
|
||||
if !ok {
|
||||
}
|
||||
|
||||
if len(alias) != 0 {
|
||||
name = alias
|
||||
}
|
||||
cols = append(cols, fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name)))
|
||||
}
|
||||
|
||||
return cols
|
||||
}
|
||||
|
||||
func writeAsStatements(q *Query) []string {
|
||||
cols := make([]string, len(q.selectCols))
|
||||
for _, col := range q.selectCols {
|
||||
if !rgxIdentifier.Match {
|
||||
cols
|
||||
for i, col := range q.selectCols {
|
||||
if !rgxIdentifier.MatchString(col) {
|
||||
cols[i] = col
|
||||
continue
|
||||
}
|
||||
|
||||
toks := strings.Split(col, ".")
|
||||
if len(toks) == 1 {
|
||||
cols[i] = strmangle.IdentQuote(col)
|
||||
continue
|
||||
}
|
||||
|
||||
asParts := make([]string, len(toks))
|
||||
for j, tok := range toks {
|
||||
asParts[j] = strings.Trim(tok, `"`)
|
||||
}
|
||||
|
||||
cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, "."))
|
||||
}
|
||||
|
||||
return cols
|
||||
}
|
||||
|
||||
func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) {
|
||||
|
@ -175,14 +212,25 @@ func identifierMapping(q *Query) map[string]string {
|
|||
// It only evaluates the first 3 tokens (anything past that is superfluous)
|
||||
// It stops parsing when it finds "on" or an invalid identifier
|
||||
func parseIdentifierClause(tokens []string, setID func(string, string)) {
|
||||
var name, alias string
|
||||
sawIdent, sawAs := false, false
|
||||
|
||||
if len(tokens) > 3 {
|
||||
tokens = tokens[:3]
|
||||
alias, name, ok := parseFromClause(tokens)
|
||||
if !ok {
|
||||
panic("could not parse from statement")
|
||||
}
|
||||
|
||||
for _, tok := range tokens {
|
||||
if len(alias) > 0 {
|
||||
setID(alias, name)
|
||||
} else {
|
||||
setID(name, name)
|
||||
}
|
||||
}
|
||||
|
||||
func parseFromClause(toks []string) (alias, name string, ok bool) {
|
||||
if len(toks) > 3 {
|
||||
toks = toks[:3]
|
||||
}
|
||||
|
||||
sawIdent, sawAs := false, false
|
||||
for _, tok := range toks {
|
||||
if t := strings.ToLower(tok); sawIdent && t == "as" {
|
||||
sawAs = true
|
||||
continue
|
||||
|
@ -201,11 +249,8 @@ func parseIdentifierClause(tokens []string, setID func(string, string)) {
|
|||
|
||||
name = strings.Trim(tok, `"`)
|
||||
sawIdent = true
|
||||
ok = true
|
||||
}
|
||||
|
||||
if len(alias) > 0 {
|
||||
setID(alias, name)
|
||||
} else {
|
||||
setID(name, name)
|
||||
}
|
||||
return alias, name, ok
|
||||
}
|
||||
|
|
|
@ -35,6 +35,10 @@ func TestBuildQuery(t *testing.T) {
|
|||
from: []string{"happiness as a"},
|
||||
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
|
||||
}, nil},
|
||||
{&Query{
|
||||
from: []string{"happiness as a"},
|
||||
joins: []join{{clause: "rainbows r on a.id = r.happy_id"}},
|
||||
}, nil},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
|
@ -120,3 +124,70 @@ func TestIdentifierMapping(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteStars(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
In Query
|
||||
Out []string
|
||||
}{
|
||||
{
|
||||
In: Query{from: []string{`a`}},
|
||||
Out: []string{`"a".*`},
|
||||
},
|
||||
{
|
||||
In: Query{from: []string{`a as b`}},
|
||||
Out: []string{`"b".*`},
|
||||
},
|
||||
{
|
||||
In: Query{from: []string{`a as b`, `c`}},
|
||||
Out: []string{`"b".*`, `"c".*`},
|
||||
},
|
||||
{
|
||||
In: Query{from: []string{`a as b`, `c as d`}},
|
||||
Out: []string{`"b".*`, `"d".*`},
|
||||
},
|
||||
}
|
||||
|
||||
for i, test := range tests {
|
||||
selects := writeStars(&test.In)
|
||||
if !reflect.DeepEqual(selects, test.Out) {
|
||||
t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteAsStatements(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
query := Query{
|
||||
selectCols: []string{
|
||||
`a`,
|
||||
`a.fun`,
|
||||
`"b"."fun"`,
|
||||
`"b".fun`,
|
||||
`b."fun"`,
|
||||
`a.clown.run`,
|
||||
`COUNT(a)`,
|
||||
},
|
||||
}
|
||||
|
||||
expect := []string{
|
||||
`"a"`,
|
||||
`"a"."fun" as "a.fun"`,
|
||||
`"b"."fun" as "b.fun"`,
|
||||
`"b"."fun" as "b.fun"`,
|
||||
`"b"."fun" as "b.fun"`,
|
||||
`"a"."clown"."run" as "a.clown.run"`,
|
||||
`COUNT(a)`,
|
||||
}
|
||||
|
||||
gots := writeAsStatements(&query)
|
||||
|
||||
for i, got := range gots {
|
||||
if expect[i] != got {
|
||||
t.Errorf(`%d) want: %s, got: %s`, i, expect[i], got)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue