Add parsing routine for SQL from/join statements

- Bonus: Add yaml/toml struct tags for models
This commit is contained in:
Aaron L 2016-08-05 23:51:13 -07:00
parent 2d608c4e17
commit 307fe4919a
3 changed files with 132 additions and 1 deletions

View file

@ -4,6 +4,7 @@ import (
"bytes"
"database/sql"
"fmt"
"regexp"
"strings"
"github.com/nullbio/sqlboiler/strmangle"
@ -290,3 +291,77 @@ func whereClause(q *Query) (string, []interface{}) {
return buf.String(), args
}
var (
rgxIdentifier = regexp.MustCompile(`^(?i)"?[a-z_][_a-z0-9]*"?(?:\."?[_a-z][_a-z0-9]*"?)*$`)
rgxJoinIdentifiers = regexp.MustCompile(`^(?i)(?:join|inner|natural|outer|left|right)$`)
)
// identifierMapping creates a map of all identifiers to potential model names
func identifierMapping(q *Query) map[string]string {
var ids map[string]string
setID := func(alias, name string) {
if ids == nil {
ids = make(map[string]string)
}
ids[alias] = name
}
for _, from := range q.from {
tokens := strings.Split(from, " ")
parseIdentifierClause(tokens, setID)
}
for _, join := range q.innerJoins {
tokens := strings.Split(join.on, " ")
discard := 0
for rgxJoinIdentifiers.MatchString(tokens[discard]) {
discard++
}
parseIdentifierClause(tokens[discard:], setID)
}
return ids
}
// parseBits takes a set of tokens and looks for something of the form:
// a b
// a as b
// where 'a' and 'b' are valid SQL identifiers
// It only evaluates the first 3 tokens (anything past that is superfluous)
// It stops parsing when it finds "on" or an invalid identifier
func parseIdentifierClause(tokens []string, setID func(string, string)) {
var name, alias string
sawIdent, sawAs := false, false
if len(tokens) > 3 {
tokens = tokens[:3]
}
for _, tok := range tokens {
if t := strings.ToLower(tok); sawIdent && t == "as" {
sawAs = true
continue
} else if sawIdent && t == "on" {
break
}
if !rgxIdentifier.MatchString(tok) {
break
}
if sawIdent || sawAs {
alias = strings.Trim(tok, `"`)
break
}
name = strings.Trim(tok, `"`)
sawIdent = true
}
if len(alias) > 0 {
setID(alias, name)
} else {
setID(name, name)
}
}

View file

@ -332,3 +332,59 @@ func TestInnerJoin(t *testing.T) {
t.Errorf("Got invalid innerJoin on string: %#v", q.innerJoins)
}
}
func TestIdentifierMapping(t *testing.T) {
t.Parallel()
tests := []struct {
In Query
Out map[string]string
}{
{
In: Query{from: []string{`a`}},
Out: map[string]string{"a": "a"},
},
{
In: Query{from: []string{`"a"`, `b`}},
Out: map[string]string{"a": "a", "b": "b"},
},
{
In: Query{from: []string{`a as b`}},
Out: map[string]string{"b": "a"},
},
{
In: Query{from: []string{`a AS "b"`, `"c" as d`}},
Out: map[string]string{"b": "a", "d": "c"},
},
{
In: Query{innerJoins: []join{{on: `inner join a on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `outer join "a" on stuff = there`}}},
Out: map[string]string{"a": "a"},
},
{
In: Query{innerJoins: []join{{on: `natural join a as b on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
{
In: Query{innerJoins: []join{{on: `right outer join "a" as "b" on stuff = there`}}},
Out: map[string]string{"b": "a"},
},
}
for i, test := range tests {
m := identifierMapping(&test.In)
for k, v := range test.Out {
val, ok := m[k]
if !ok {
t.Errorf("%d) want: %s = %s, but was missing", i, k, v)
}
if val != v {
t.Errorf("%d) want: %s = %s, got: %s", i, k, v, val)
}
}
}
}

View file

@ -3,6 +3,6 @@
// {{$modelName}} is an object representing the database table.
type {{$modelName}} struct {
{{range $column := .Table.Columns -}}
{{titleCase $column.Name}} {{$column.Type}} `boil:"{{printf "%s.%s" $tableNameSingular $column.Name}}" json:"{{$column.Name}}"`
{{titleCase $column.Name}} {{$column.Type}} `boil:"{{printf "%s.%s" $tableNameSingular $column.Name}}" json:"{{$column.Name}}" toml:"{{$column.Name}}" yaml:"{{$column.Name}}"`
{{end -}}
}