sqlboiler/imports.go
2016-08-18 00:06:28 -07:00

298 lines
6.1 KiB
Go

package main
import (
"bytes"
"fmt"
"sort"
"strings"
"github.com/vattle/sqlboiler/bdb"
)
// imports groups import specs into two optional sets: standard imports and
// thirdParty imports (from github for example). Either list may be left
// unset. buildImportString renders the standard list first, then the
// thirdParty list.
type imports struct {
// standard holds the import specs for the standard group.
standard importList
// thirdParty holds the import specs for the third-party group.
thirdParty importList
}
// importList is a list of import specs. Each entry is written exactly as it
// should appear inside an import declaration, including the quotes and any
// alias (for example `"fmt"` or `_ "github.com/lib/pq"`).
//
// importList implements sort.Interface.
type importList []string

// Len returns the number of import specs in the list.
func (i importList) Len() int {
	return len(i)
}

// Swap exchanges the import specs at indexes k and j.
func (i importList) Swap(k, j int) {
	i[k], i[j] = i[j], i[k]
}

// Less reports whether the spec at index k must sort before the spec at
// index j. Leading underscores and spaces are stripped before comparing, so
// a blank import such as `_ "pkg"` is ordered by its quoted path rather than
// by the underscore.
//
// The comparison is strict: two specs that are equal after trimming are never
// "less" than one another. sort.Interface requires this; reporting true for
// equal elements would make Less(k, j) and Less(j, k) both true.
func (i importList) Less(k, j int) bool {
	return strings.TrimLeft(i[k], "_ ") < strings.TrimLeft(i[j], "_ ")
}
// combineImports merges the import sets a and b into a new set. For each
// group (standard and thirdParty) the two lists are concatenated into a fresh
// slice, duplicate specs are dropped, and the result is sorted using
// importList's ordering. Neither a nor b is modified.
func combineImports(a, b imports) imports {
	std := removeDuplicates(combineStringSlices(a.standard, b.standard))
	third := removeDuplicates(combineStringSlices(a.thirdParty, b.thirdParty))

	merged := imports{
		standard:   std,
		thirdParty: third,
	}
	sort.Sort(merged.standard)
	sort.Sort(merged.thirdParty)

	return merged
}
// combineTypeImports returns a copy of a extended with the type-specific
// imports required by columns.
//
// b maps a column type to the imports that type needs. For every column whose
// Type is a key of b, that entry's standard and thirdParty specs are appended
// to the result. Duplicates are then removed and both lists are sorted using
// importList's ordering. The slices held by a are copied first, so a itself
// is never modified.
func combineTypeImports(a imports, b map[string]imports, columns []bdb.Column) imports {
	tmpImp := imports{
		standard:   make(importList, len(a.standard)),
		thirdParty: make(importList, len(a.thirdParty)),
	}
	copy(tmpImp.standard, a.standard)
	copy(tmpImp.thirdParty, a.thirdParty)

	for _, col := range columns {
		// Look the type up directly; a map holds at most one entry per key,
		// so there is no need to scan every key for each column.
		imp, ok := b[col.Type]
		if !ok {
			continue
		}
		tmpImp.standard = append(tmpImp.standard, imp.standard...)
		tmpImp.thirdParty = append(tmpImp.thirdParty, imp.thirdParty...)
	}

	tmpImp.standard = removeDuplicates(tmpImp.standard)
	tmpImp.thirdParty = removeDuplicates(tmpImp.thirdParty)

	sort.Sort(tmpImp.standard)
	sort.Sort(tmpImp.thirdParty)

	return tmpImp
}
// buildImportString renders imps as Go import source text.
//
//   - No imports at all: an empty, non-nil byte slice.
//   - Exactly one import: a single-line `import <spec>` with no trailing
//     newline.
//   - Otherwise: a parenthesised import block listing the standard specs
//     first and the thirdParty specs second, one per tab-indented line, with
//     a blank line between the groups when both are non-empty. The block ends
//     with ")\n".
func buildImportString(imps imports) []byte {
	nStd, nThird := len(imps.standard), len(imps.thirdParty)

	switch nStd + nThird {
	case 0:
		return []byte{}
	case 1:
		// The single spec lives in whichever list is non-empty.
		only := imps.thirdParty
		if nStd == 1 {
			only = imps.standard
		}
		return []byte(fmt.Sprintf("import %s", only[0]))
	}

	var out bytes.Buffer
	out.WriteString("import (")

	for _, spec := range imps.standard {
		fmt.Fprintf(&out, "\n\t%s", spec)
	}

	// Separate the two groups with an empty line only when both exist.
	if nStd > 0 && nThird > 0 {
		out.WriteString("\n")
	}

	for _, spec := range imps.thirdParty {
		fmt.Fprintf(&out, "\n\t%s", spec)
	}

	out.WriteString("\n)\n")
	return out.Bytes()
}
// combineStringSlices returns a newly allocated slice holding the elements of
// a followed by the elements of b. The result never shares storage with
// either input and is non-nil even when both inputs are empty.
func combineStringSlices(a, b []string) []string {
	joined := make([]string, len(a)+len(b))
	// copy is a no-op for an empty source and reports how many elements it
	// wrote, which is exactly the offset at which b must start.
	n := copy(joined, a)
	copy(joined[n:], b)
	return joined
}
// removeDuplicates drops repeated strings from dedup and returns the
// shortened slice. It works in place: the backing array of dedup is
// overwritten and the returned slice shares it, so callers must use the
// return value and must not rely on the argument afterwards.
//
// The first occurrence of each value keeps its position. A later duplicate is
// replaced by the current last element before the slice is shortened, so the
// relative order of the remaining elements is not preserved in general.
func removeDuplicates(dedup []string) []string {
	if len(dedup) < 2 {
		return dedup
	}

	for i := 0; i < len(dedup)-1; i++ {
		j := i + 1
		for j < len(dedup) {
			if dedup[j] != dedup[i] {
				j++
				continue
			}

			// Overwrite the duplicate with the tail element and shrink by
			// one. When the duplicate is itself the tail this is a
			// self-assignment. j is not advanced, so the element that was
			// just moved into slot j is examined on the next pass.
			last := len(dedup) - 1
			dedup[j] = dedup[last]
			dedup = dedup[:last]
		}
	}

	return dedup
}
// defaultTemplateImports is an import set made of a standard group and a
// thirdParty group. Each entry is a complete import spec, quotes included.
// The lists are not stored in sorted order.
// NOTE(review): the consumer is outside this file; judging by the name this
// is the base import set for the regular (non-singleton, non-test)
// templates — confirm against the caller.
var defaultTemplateImports = imports{
standard: importList{
`"fmt"`,
`"strings"`,
`"database/sql"`,
},
thirdParty: importList{
`"github.com/pkg/errors"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/qm"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
}
// defaultSingletonTemplateImports maps a name to the import set associated
// with it. An entry may leave one of the two groups unset (see "boil_types",
// which only has thirdParty imports).
// NOTE(review): the keys look like singleton template names without a file
// extension — confirm against the code that performs the lookup.
var defaultSingletonTemplateImports = map[string]imports{
"boil_queries": {
standard: importList{
`"fmt"`,
`"strings"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/qm"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
},
"boil_types": {
thirdParty: importList{
`"github.com/pkg/errors"`,
},
},
}
// defaultTestTemplateImports is an import set made of a standard group and a
// thirdParty group. Each entry is a complete import spec, quotes included.
// The lists are not stored in sorted order.
// NOTE(review): judging by the name this is the base import set for the
// generated per-table test templates — confirm against the caller.
var defaultTestTemplateImports = imports{
standard: importList{
`"testing"`,
`"reflect"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/boil/randomize"`,
`"github.com/vattle/sqlboiler/strmangle"`,
},
}
// defaultSingletonTestTemplateImports maps a name to the import set
// associated with it. An entry may leave one of the two groups unset (see
// "boil_suites_test", which only has standard imports).
// NOTE(review): the keys look like singleton test template names without a
// file extension — confirm against the code that performs the lookup.
var defaultSingletonTestTemplateImports = map[string]imports{
"boil_viper_test": {
standard: importList{
`"database/sql"`,
`"os"`,
`"path/filepath"`,
},
thirdParty: importList{
`"github.com/spf13/viper"`,
},
},
"boil_queries_test": {
standard: importList{
`"crypto/md5"`,
`"fmt"`,
`"os"`,
`"strconv"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/vattle/sqlboiler/boil"`,
},
},
"boil_suites_test": {
standard: importList{
`"testing"`,
},
},
}
// defaultTestMainImports maps a name to the import set associated with it.
// Only a "postgres" entry is defined here.
// NOTE(review): the key is presumably the database driver name, and the set
// presumably feeds the generated test main file — confirm against the caller.
var defaultTestMainImports = map[string]imports{
"postgres": {
standard: importList{
`"testing"`,
`"os"`,
`"os/exec"`,
`"flag"`,
`"fmt"`,
`"io/ioutil"`,
`"bytes"`,
`"database/sql"`,
`"path/filepath"`,
`"time"`,
`"math/rand"`,
},
thirdParty: importList{
`"github.com/kat-co/vala"`,
`"github.com/pkg/errors"`,
`"github.com/spf13/viper"`,
`"github.com/vattle/sqlboiler/boil"`,
`"github.com/vattle/sqlboiler/bdb/drivers"`,
// Blank import: the package is imported only for its side effects.
`_ "github.com/lib/pq"`,
},
},
}
// importsBasedOnType lists imports that are only included in the template
// output if the database requires one of the following special types. Check
// TranslateColumnType to see the type assignments.
//
// The map is keyed by Go type name. Every null.* type needs the same
// third-party null package; time.Time needs the standard "time" package.
// NOTE(review): presumably this is the map passed as b to
// combineTypeImports, where keys are compared against a column's Type —
// confirm against the caller.
var importsBasedOnType = map[string]imports{
"null.Float32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Float64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Int64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint8": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint16": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint32": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Uint64": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.String": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Bool": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"null.Time": {
thirdParty: importList{`"gopkg.in/nullbio/null.v4"`},
},
"time.Time": {
standard: importList{`"time"`},
},
}