Add imports library functionality

This commit is contained in:
Patrick O'brien 2017-02-05 20:17:11 +10:00
parent dea061f571
commit e2c53aa312
4 changed files with 464 additions and 196 deletions

View file

@ -43,6 +43,8 @@ type State struct {
SingletonTestTemplates *templateList SingletonTestTemplates *templateList
TestMainTemplate *template.Template TestMainTemplate *template.Template
Importer importer
} }
// New creates a new state based off of the config // New creates a new state based off of the config
@ -89,6 +91,8 @@ func New(config *Config) (*State, error) {
return nil, errors.Wrap(err, "unable to initialize struct tags") return nil, errors.Wrap(err, "unable to initialize struct tags")
} }
s.Importer = newImporter()
return s, nil return s, nil
} }

View file

@ -141,211 +141,302 @@ func removeDuplicates(dedup []string) []string {
return dedup return dedup
} }
var defaultTemplateImports = imports{ type mapImports map[string]imports
standard: importList{
`"bytes"`, type importer struct {
`"database/sql"`, standard imports
`"fmt"`, testStandard imports
`"reflect"`,
`"strings"`, singleton mapImports
`"sync"`, testSingleton mapImports
`"time"`,
}, testMain mapImports
thirdParty: importList{
`"github.com/pkg/errors"`, basedOnType mapImports
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
} }
var defaultSingletonTemplateImports = map[string]imports{ // newImporter returns an importer struct with default import values
"boil_queries": { func newImporter() importer {
thirdParty: importList{ var imp importer
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
},
},
"boil_types": {
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
},
}
var defaultTestTemplateImports = imports{ imp.standard = imports{
standard: importList{
`"bytes"`,
`"reflect"`,
`"testing"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/randomize"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
}
var defaultSingletonTestTemplateImports = map[string]imports{
"boil_main_test": {
standard: importList{ standard: importList{
`"bytes"`,
`"database/sql"`, `"database/sql"`,
`"flag"`,
`"fmt"`, `"fmt"`,
`"math/rand"`, `"reflect"`,
`"os"`, `"strings"`,
`"path/filepath"`, `"sync"`,
`"testing"`,
`"time"`, `"time"`,
}, },
thirdParty: importList{ thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`, `"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`, `"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
`"github.com/vattle/sqlboiler/strmangle"`,
}, },
}, }
"boil_queries_test": {
imp.singleton = mapImports{
"boil_queries": {
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/queries"`,
`"github.com/vattle/sqlboiler/queries/qm"`,
},
},
"boil_types": {
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
},
}
imp.testStandard = imports{
standard: importList{ standard: importList{
`"bytes"`, `"bytes"`,
`"fmt"`, `"reflect"`,
`"io"`,
`"io/ioutil"`,
`"math/rand"`,
`"regexp"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
},
},
"boil_suites_test": {
standard: importList{
`"testing"`, `"testing"`,
}, },
}, thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/randomize"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
}
imp.testSingleton = mapImports{
"boil_main_test": {
standard: importList{
`"database/sql"`,
`"flag"`,
`"fmt"`,
`"math/rand"`,
`"os"`,
`"path/filepath"`,
`"testing"`,
`"time"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
},
},
"boil_queries_test": {
standard: importList{
`"bytes"`,
`"fmt"`,
`"io"`,
`"io/ioutil"`,
`"math/rand"`,
`"regexp"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
},
},
"boil_suites_test": {
standard: importList{
`"testing"`,
},
},
}
imp.testMain = mapImports{
"postgres": {
standard: importList{
`"bytes"`,
`"database/sql"`,
`"fmt"`,
`"io"`,
`"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/lib/pq"`,
},
},
"mysql": {
standard: importList{
`"bytes"`,
`"database/sql"`,
`"fmt"`,
`"io"`,
`"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/go-sql-driver/mysql"`,
},
},
}
// basedOnType imports are only included in the template output if the
// database requires one of the following special types. Check
// TranslateColumnType to see the type assignments.
imp.basedOnType = mapImports{
"null.Float32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Float64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Int64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Uint64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.String": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Bool": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Time": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.JSON": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Bytes": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"time.Time": {
standard: importList{`"time"`},
},
"types.JSON": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BytesArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Int64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Float64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BoolArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.StringArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Hstore": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
}
return imp
} }
var defaultTestMainImports = map[string]imports{ // Remove an import matching the match string under the specified key.
"postgres": { // Remove will search both standard and thirdParty import lists for a match.
standard: importList{ func (m mapImports) Remove(key string, match string) {
`"bytes"`, mp := m[key]
`"database/sql"`, for idx := 0; idx < len(mp.standard); idx++ {
`"fmt"`, if mp.standard[idx] == match {
`"io"`, mp.standard[idx] = mp.standard[len(mp.standard)-1]
`"io/ioutil"`, mp.standard = mp.standard[:len(mp.standard)-1]
`"os"`, break
`"os/exec"`, }
`"strings"`, }
}, for idx := 0; idx < len(mp.thirdParty); idx++ {
thirdParty: importList{ if mp.thirdParty[idx] == match {
`"github.com/pkg/errors"`, mp.thirdParty[idx] = mp.thirdParty[len(mp.thirdParty)-1]
`"github.com/spf13/viper"`, mp.thirdParty = mp.thirdParty[:len(mp.thirdParty)-1]
`"github.com/vattle/sqlboiler/bdb/drivers"`, break
`"github.com/vattle/sqlboiler/randomize"`, }
`_ "github.com/lib/pq"`, }
},
}, // delete the key and return if both import lists are empty
"mysql": { if len(mp.thirdParty) == 0 && len(mp.standard) == 0 {
standard: importList{ delete(m, key)
`"bytes"`, return
`"database/sql"`, }
`"fmt"`,
`"io"`, m[key] = mp
`"io/ioutil"`,
`"os"`,
`"os/exec"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
`"github.com/vattle/sqlboiler/randomize"`,
`_ "github.com/go-sql-driver/mysql"`,
},
},
} }
// importsBasedOnType imports are only included in the template output if the // Add an import under the specified key. If the key does not exist, it
// database requires one of the following special types. Check // will be created.
// TranslateColumnType to see the type assignments. func (m mapImports) Add(key string, value string, thirdParty bool) {
var importsBasedOnType = map[string]imports{ mp := m[key]
"null.Float32": { if thirdParty {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, mp.thirdParty = append(mp.thirdParty, value)
}, } else {
"null.Float64": { mp.standard = append(mp.standard, value)
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }
},
"null.Int": { m[key] = mp
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }
},
"null.Int8": { // Remove an import matching the match string under the specified key.
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, // Remove will search both standard and thirdParty import lists for a match.
}, func (i *imports) Remove(match string) {
"null.Int16": { for idx := 0; idx < len(i.standard); idx++ {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, if i.standard[idx] == match {
}, i.standard[idx] = i.standard[len(i.standard)-1]
"null.Int32": { i.standard = i.standard[:len(i.standard)-1]
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, break
}, }
"null.Int64": { }
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, for idx := 0; idx < len(i.thirdParty); idx++ {
}, if i.thirdParty[idx] == match {
"null.Uint": { i.thirdParty[idx] = i.thirdParty[len(i.thirdParty)-1]
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, i.thirdParty = i.thirdParty[:len(i.thirdParty)-1]
}, break
"null.Uint8": { }
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, }
}, }
"null.Uint16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, // Add an import under the specified key. If the key does not exist, it
}, // will be created.
"null.Uint32": { func (i *imports) Add(value string, thirdParty bool) {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, if thirdParty {
}, i.thirdParty = append(i.thirdParty, value)
"null.Uint64": { } else {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`}, i.standard = append(i.standard, value)
}, }
"null.String": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Bool": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Time": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.JSON": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"null.Bytes": {
thirdParty: importList{`"gopkg.in/nullbio/null.v6"`},
},
"time.Time": {
standard: importList{`"time"`},
},
"types.JSON": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BytesArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Int64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Float64Array": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.BoolArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.StringArray": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
"types.Hstore": {
thirdParty: importList{`"github.com/vattle/sqlboiler/types"`},
},
} }

View file

@ -54,6 +54,177 @@ func TestImportsSort(t *testing.T) {
} }
} }
func TestImportsAddAndRemove(t *testing.T) {
t.Parallel()
var imp imports
imp.Add("value", false)
if len(imp.standard) != 1 {
t.Errorf("expected len 1, got %d", len(imp.standard))
}
if imp.standard[0] != "value" {
t.Errorf("expected %q to be added", "value")
}
imp.Add("value2", true)
if len(imp.thirdParty) != 1 {
t.Errorf("expected len 1, got %d", len(imp.thirdParty))
}
if imp.thirdParty[0] != "value2" {
t.Errorf("expected %q to be added", "value2")
}
imp.Remove("value")
if len(imp.standard) != 0 {
t.Errorf("expected len 0, got %d", len(imp.standard))
}
imp.Remove("value")
if len(imp.standard) != 0 {
t.Errorf("expected len 0, got %d", len(imp.standard))
}
imp.Remove("value2")
if len(imp.thirdParty) != 0 {
t.Errorf("expected len 0, got %d", len(imp.thirdParty))
}
// Test deleting last element in len 2 slice
imp.Add("value3", false)
imp.Add("value4", false)
if len(imp.standard) != 2 {
t.Errorf("expected len 2, got %d", len(imp.standard))
}
imp.Remove("value4")
if len(imp.standard) != 1 {
t.Errorf("expected len 1, got %d", len(imp.standard))
}
if imp.standard[0] != "value3" {
t.Errorf("expected %q, got %q", "value3", imp.standard[0])
}
// Test deleting first element in len 2 slice
imp.Add("value4", false)
imp.Remove("value3")
if len(imp.standard) != 1 {
t.Errorf("expected len 1, got %d", len(imp.standard))
}
if imp.standard[0] != "value4" {
t.Errorf("expected %q, got %q", "value4", imp.standard[0])
}
imp.Remove("value2")
if len(imp.thirdParty) != 0 {
t.Errorf("expected len 0, got %d", len(imp.thirdParty))
}
// Test deleting last element in len 2 slice
imp.Add("value5", true)
imp.Add("value6", true)
if len(imp.thirdParty) != 2 {
t.Errorf("expected len 2, got %d", len(imp.thirdParty))
}
imp.Remove("value6")
if len(imp.thirdParty) != 1 {
t.Errorf("expected len 1, got %d", len(imp.thirdParty))
}
if imp.thirdParty[0] != "value5" {
t.Errorf("expected %q, got %q", "value5", imp.thirdParty[0])
}
// Test deleting first element in len 2 slice
imp.Add("value6", true)
imp.Remove("value5")
if len(imp.thirdParty) != 1 {
t.Errorf("expected len 1, got %d", len(imp.thirdParty))
}
if imp.thirdParty[0] != "value6" {
t.Errorf("expected %q, got %q", "value6", imp.thirdParty[0])
}
}
func TestMapImportsAddAndRemove(t *testing.T) {
t.Parallel()
imp := mapImports{}
imp.Add("cat", "value", false)
if len(imp["cat"].standard) != 1 {
t.Errorf("expected len 1, got %d", len(imp["cat"].standard))
}
if imp["cat"].standard[0] != "value" {
t.Errorf("expected %q to be added", "value")
}
imp.Add("cat", "value2", true)
if len(imp["cat"].thirdParty) != 1 {
t.Errorf("expected len 1, got %d", len(imp["cat"].thirdParty))
}
if imp["cat"].thirdParty[0] != "value2" {
t.Errorf("expected %q to be added", "value2")
}
imp.Remove("cat", "value")
if len(imp["cat"].standard) != 0 {
t.Errorf("expected len 0, got %d", len(imp["cat"].standard))
}
imp.Remove("cat", "value")
if len(imp["cat"].standard) != 0 {
t.Errorf("expected len 0, got %d", len(imp["cat"].standard))
}
imp.Remove("cat", "value2")
if len(imp["cat"].thirdParty) != 0 {
t.Errorf("expected len 0, got %d", len(imp["cat"].thirdParty))
}
// If there are no elements left in key, test key is deleted
_, ok := imp["cat"]
if ok {
t.Errorf("expected cat key to be deleted when list empty")
}
// Test deleting last element in len 2 slice
imp.Add("cat", "value3", false)
imp.Add("cat", "value4", false)
if len(imp["cat"].standard) != 2 {
t.Errorf("expected len 2, got %d", len(imp["cat"].standard))
}
imp.Remove("cat", "value4")
if len(imp["cat"].standard) != 1 {
t.Errorf("expected len 1, got %d", len(imp["cat"].standard))
}
if imp["cat"].standard[0] != "value3" {
t.Errorf("expected %q, got %q", "value3", imp["cat"].standard[0])
}
// Test deleting first element in len 2 slice
imp.Add("cat", "value4", false)
imp.Remove("cat", "value3")
if len(imp["cat"].standard) != 1 {
t.Errorf("expected len 1, got %d", len(imp["cat"].standard))
}
if imp["cat"].standard[0] != "value4" {
t.Errorf("expected %q, got %q", "value4", imp["cat"].standard[0])
}
imp.Remove("cat", "value2")
if len(imp["cat"].thirdParty) != 0 {
t.Errorf("expected len 0, got %d", len(imp["cat"].thirdParty))
}
// Test deleting last element in len 2 slice
imp.Add("dog", "value5", true)
imp.Add("dog", "value6", true)
if len(imp["dog"].thirdParty) != 2 {
t.Errorf("expected len 2, got %d", len(imp["dog"].thirdParty))
}
imp.Remove("dog", "value6")
if len(imp["dog"].thirdParty) != 1 {
t.Errorf("expected len 1, got %d", len(imp["dog"].thirdParty))
}
if imp["dog"].thirdParty[0] != "value5" {
t.Errorf("expected %q, got %q", "value5", imp["dog"].thirdParty[0])
}
// Test deleting first element in len 2 slice
imp.Add("dog", "value6", true)
imp.Remove("dog", "value5")
if len(imp["dog"].thirdParty) != 1 {
t.Errorf("expected len 1, got %d", len(imp["dog"].thirdParty))
}
if imp["dog"].thirdParty[0] != "value6" {
t.Errorf("expected %q, got %q", "value6", imp["dog"].thirdParty[0])
}
}
func TestCombineTypeImports(t *testing.T) { func TestCombineTypeImports(t *testing.T) {
t.Parallel() t.Parallel()
@ -94,7 +265,9 @@ func TestCombineTypeImports(t *testing.T) {
}, },
} }
res1 := combineTypeImports(imports1, importsBasedOnType, cols) imps := newImporter()
res1 := combineTypeImports(imports1, imps.basedOnType, cols)
if !reflect.DeepEqual(res1, importsExpected) { if !reflect.DeepEqual(res1, importsExpected) {
t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1) t.Errorf("Expected res1 to match importsExpected, got:\n\n%#v\n", res1)
@ -112,7 +285,7 @@ func TestCombineTypeImports(t *testing.T) {
}, },
} }
res2 := combineTypeImports(imports2, importsBasedOnType, cols) res2 := combineTypeImports(imports2, imps.basedOnType, cols)
if !reflect.DeepEqual(res2, importsExpected) { if !reflect.DeepEqual(res2, importsExpected) {
t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1) t.Errorf("Expected res2 to match importsExpected, got:\n\n%#v\n", res1)

View file

@ -38,7 +38,7 @@ func generateOutput(state *State, data *templateData) error {
state: state, state: state,
data: data, data: data,
templates: state.Templates, templates: state.Templates,
importSet: defaultTemplateImports, importSet: state.Importer.standard,
combineImportsOnType: true, combineImportsOnType: true,
fileSuffix: ".go", fileSuffix: ".go",
}) })
@ -50,7 +50,7 @@ func generateTestOutput(state *State, data *templateData) error {
state: state, state: state,
data: data, data: data,
templates: state.TestTemplates, templates: state.TestTemplates,
importSet: defaultTestTemplateImports, importSet: state.Importer.testStandard,
combineImportsOnType: false, combineImportsOnType: false,
fileSuffix: "_test.go", fileSuffix: "_test.go",
}) })
@ -63,7 +63,7 @@ func generateSingletonOutput(state *State, data *templateData) error {
state: state, state: state,
data: data, data: data,
templates: state.SingletonTemplates, templates: state.SingletonTemplates,
importNamedSet: defaultSingletonTemplateImports, importNamedSet: state.Importer.singleton,
fileSuffix: ".go", fileSuffix: ".go",
}) })
} }
@ -75,7 +75,7 @@ func generateSingletonTestOutput(state *State, data *templateData) error {
state: state, state: state,
data: data, data: data,
templates: state.SingletonTestTemplates, templates: state.SingletonTestTemplates,
importNamedSet: defaultSingletonTestTemplateImports, importNamedSet: state.Importer.testSingleton,
fileSuffix: ".go", fileSuffix: ".go",
}) })
} }
@ -106,7 +106,7 @@ func executeTemplates(e executeTemplateData) error {
imps.standard = e.importSet.standard imps.standard = e.importSet.standard
imps.thirdParty = e.importSet.thirdParty imps.thirdParty = e.importSet.thirdParty
if e.combineImportsOnType { if e.combineImportsOnType {
imps = combineTypeImports(imps, importsBasedOnType, e.data.Table.Columns) imps = combineTypeImports(imps, e.state.Importer.basedOnType, e.data.Table.Columns)
} }
writeFileDisclaimer(out) writeFileDisclaimer(out)
@ -170,8 +170,8 @@ func generateTestMainOutput(state *State, data *templateData) error {
out.Reset() out.Reset()
var imps imports var imps imports
imps.standard = defaultTestMainImports[state.Config.DriverName].standard imps.standard = state.Importer.testMain[state.Config.DriverName].standard
imps.thirdParty = defaultTestMainImports[state.Config.DriverName].thirdParty imps.thirdParty = state.Importer.testMain[state.Config.DriverName].thirdParty
writeFileDisclaimer(out) writeFileDisclaimer(out)
writePackageName(out, state.Config.PkgName) writePackageName(out, state.Config.PkgName)