diff --git a/boil/_fixtures/05.sql b/boil/_fixtures/05.sql new file mode 100644 index 0000000..8adf761 --- /dev/null +++ b/boil/_fixtures/05.sql @@ -0,0 +1 @@ +SELECT "a"."happy" as "a.happy","r"."fun" as "r.fun","q" FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id; \ No newline at end of file diff --git a/boil/_fixtures/06.sql b/boil/_fixtures/06.sql new file mode 100644 index 0000000..99ff89f --- /dev/null +++ b/boil/_fixtures/06.sql @@ -0,0 +1 @@ +SELECT "a".* FROM happiness as a INNER JOIN rainbows r on a.id = r.happy_id; \ No newline at end of file diff --git a/boil/query_builders.go b/boil/query_builders.go index 0f1c729..1f90b4f 100644 --- a/boil/query_builders.go +++ b/boil/query_builders.go @@ -10,8 +10,7 @@ import ( ) var ( - rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`) - rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`) + rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`) ) func buildQuery(q *Query) (string, []interface{}) { @@ -45,7 +44,7 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { hasSelectCols := len(q.selectCols) != 0 hasJoins := len(q.joins) != 0 - if hasJoins && !hasModFunc { + if hasSelectCols && hasJoins && !hasModFunc { selectColsWithAs := writeAsStatements(q) // Don't identQuoteSlice - writeAsStatements does this buf.WriteString(strings.Join(selectColsWithAs, `,`)) @@ -90,13 +89,51 @@ func buildSelectQuery(q *Query) (*bytes.Buffer, []interface{}) { return buf, args } -func writeComplexSelect(q *Query, buf *bytes.Buffer) { - cols := make([]string, len(q.selectCols)) - for _, col := range q.selectCols { - if !rgxIdentifier.Match { - cols +func writeStars(q *Query) []string { + cols := make([]string, 0, len(q.from)) + for _, f := range q.from { + toks := strings.Split(f, " ") + if len(toks) == 1 { + cols = append(cols, fmt.Sprintf(`%s.*`, strmangle.IdentQuote(toks[0]))) + continue } + + alias, name, ok := parseFromClause(toks) + if !ok { + } + + if len(alias) != 0 { + name = alias + } + cols = append(cols, fmt.Sprintf(`%s.*`, strmangle.IdentQuote(name))) } + + return cols +} + +func writeAsStatements(q *Query) []string { + cols := make([]string, len(q.selectCols)) + for i, col := range q.selectCols { + if !rgxIdentifier.MatchString(col) { + cols[i] = col + continue + } + + toks := strings.Split(col, ".") + if len(toks) == 1 { + cols[i] = strmangle.IdentQuote(col) + continue + } + + asParts := make([]string, len(toks)) + for j, tok := range toks { + asParts[j] = strings.Trim(tok, `"`) + } + + cols[i] = fmt.Sprintf(`%s as "%s"`, strmangle.IdentQuote(col), strings.Join(asParts, ".")) + } + + return cols } func buildDeleteQuery(q *Query) (*bytes.Buffer, []interface{}) { @@ -175,14 +212,25 @@ func identifierMapping(q *Query) map[string]string { // It only evaluates the first 3 tokens (anything past that is superfluous) // It stops parsing when it finds "on" or an invalid identifier func parseIdentifierClause(tokens []string, setID func(string, string)) { - var name, alias string - sawIdent, sawAs := false, false - - if len(tokens) > 3 { - tokens = tokens[:3] + alias, name, ok := parseFromClause(tokens) + if !ok { + panic("could not parse from statement") } - for _, tok := range tokens { + if len(alias) > 0 { + setID(alias, name) + } else { + setID(name, name) + } +} + +func parseFromClause(toks []string) (alias, name string, ok bool) { + if len(toks) > 3 { + toks = toks[:3] + } + + sawIdent, sawAs := false, false + for _, tok := range toks { if t := strings.ToLower(tok); sawIdent && t == "as" { sawAs = true continue @@ -201,11 +249,8 @@ func parseIdentifierClause(tokens []string, setID func(string, string)) { name = strings.Trim(tok, `"`) sawIdent = true + ok = true } - if len(alias) > 0 { - setID(alias, name) - } else { - setID(name, name) - } + return alias, name, ok } diff --git a/boil/query_builders_test.go b/boil/query_builders_test.go index 43955b6..f6f2847 100644 --- a/boil/query_builders_test.go +++ b/boil/query_builders_test.go @@ -35,6 +35,10 @@ func TestBuildQuery(t *testing.T) { from: []string{"happiness as a"}, joins: []join{{clause: "rainbows r on a.id = r.happy_id"}}, }, nil}, + {&Query{ + from: []string{"happiness as a"}, + joins: []join{{clause: "rainbows r on a.id = r.happy_id"}}, + }, nil}, } for i, test := range tests { @@ -120,3 +124,70 @@ func TestIdentifierMapping(t *testing.T) { } } } + +func TestWriteStars(t *testing.T) { + t.Parallel() + + tests := []struct { + In Query + Out []string + }{ + { + In: Query{from: []string{`a`}}, + Out: []string{`"a".*`}, + }, + { + In: Query{from: []string{`a as b`}}, + Out: []string{`"b".*`}, + }, + { + In: Query{from: []string{`a as b`, `c`}}, + Out: []string{`"b".*`, `"c".*`}, + }, + { + In: Query{from: []string{`a as b`, `c as d`}}, + Out: []string{`"b".*`, `"d".*`}, + }, + } + + for i, test := range tests { + selects := writeStars(&test.In) + if !reflect.DeepEqual(selects, test.Out) { + t.Errorf("writeStar test fail %d\nwant: %v\ngot: %v", i, test.Out, selects) + } + } +} + +func TestWriteAsStatements(t *testing.T) { + t.Parallel() + + query := Query{ + selectCols: []string{ + `a`, + `a.fun`, + `"b"."fun"`, + `"b".fun`, + `b."fun"`, + `a.clown.run`, + `COUNT(a)`, + }, + } + + expect := []string{ + `"a"`, + `"a"."fun" as "a.fun"`, + `"b"."fun" as "b.fun"`, + `"b"."fun" as "b.fun"`, + `"b"."fun" as "b.fun"`, + `"a"."clown"."run" as "a.clown.run"`, + `COUNT(a)`, + } + + gots := writeAsStatements(&query) + + for i, got := range gots { + if expect[i] != got { + t.Errorf(`%d) want: %s, got: %s`, i, expect[i], got) + } + } +}