2016-02-23 09:27:32 +01:00
|
|
|
package cmds
|
|
|
|
|
|
|
|
import (
	"fmt"
	"io"
	"os"
	"path/filepath"

	"github.com/pobri19/sqlboiler/dbdrivers"
	"github.com/spf13/cobra"
)
|
|
|
|
|
2016-02-29 10:39:49 +01:00
|
|
|
// CobraRunFunc declares the cobra.Command.Run function definition
|
|
|
|
type CobraRunFunc func(cmd *cobra.Command, args []string)
|
|
|
|
|
|
|
|
// CmdData holds the state shared by the generation commands: the names of the
// tables sqlboiler is generating against, the columns of each of those tables,
// the package name and output folder used for the generated files, and the
// database driver chosen by the driver flag at runtime.
type CmdData struct {
	// Tables lists the names of the tables sqlboiler is generating against.
	Tables []string
	// Columns holds one column list per table. defaultRun pairs Columns[i]
	// with Tables[i], so the two slices are expected to line up by index.
	Columns [][]dbdrivers.Column
	// PkgName is written as the package clause of every generated file.
	PkgName string
	// OutFolder is the directory generated files are written into; when it
	// is empty, generated output is written to stdout instead.
	OutFolder string
	// Interface is the database driver chosen by the driver flag at runtime.
	Interface dbdrivers.Interface
}
|
|
|
|
|
2016-02-24 09:53:34 +01:00
|
|
|
// tplData is the per-table data passed to a template when it is executed.
type tplData struct {
	// Table is the name of the table the template is being generated for.
	Table string
	// Columns are the columns of Table.
	Columns []dbdrivers.Column
	// PkgName is the package name of the generated file.
	PkgName string
}
|
|
|
|
|
2016-02-24 09:53:34 +01:00
|
|
|
// errorQuit reports err, followed by a hint pointing at the usage help, and
// then terminates the process.
//
// The message is written to stderr rather than stdout: when no output folder
// is configured, generated code is streamed to stdout, and an error message
// there would be mixed into (and corrupt) piped output.
//
// The exit status is kept at -1 (which the OS reports as 255 on Unix) so that
// scripts relying on the existing status keep working.
func errorQuit(err error) {
	fmt.Fprintf(os.Stderr, "Error: %s\n---\n\nRun 'sqlboiler --help' for usage.\n", err)
	os.Exit(-1)
}
|
|
|
|
|
2016-02-29 10:39:49 +01:00
|
|
|
// defaultRun is the default function passed to the commands cobra.Command.Run.
|
|
|
|
// It will generate the specific commands template and send it to outHandler for output.
|
|
|
|
func defaultRun(cmd *cobra.Command, args []string) {
|
2016-03-01 15:20:13 +01:00
|
|
|
// Generate the template for every table
|
2016-03-02 05:05:25 +01:00
|
|
|
for i := 0; i < len(cmdData.Columns); i++ {
|
2016-03-01 15:20:13 +01:00
|
|
|
data := tplData{
|
2016-03-02 05:05:25 +01:00
|
|
|
Table: cmdData.Tables[i],
|
|
|
|
Columns: cmdData.Columns[i],
|
2016-03-16 15:33:58 +01:00
|
|
|
PkgName: cmdData.PkgName,
|
2016-03-01 15:20:13 +01:00
|
|
|
}
|
|
|
|
|
2016-03-18 16:27:55 +01:00
|
|
|
templater(cmd, &data)
|
|
|
|
}
|
|
|
|
}
|
2016-03-01 15:20:13 +01:00
|
|
|
|
2016-03-18 16:27:55 +01:00
|
|
|
// templater generates the template by passing it the tplData object.
|
|
|
|
// Once the template is generated, it will add the imports to the output stream
|
|
|
|
// and output the contents of the template with the added bits (imports and package declaration).
|
|
|
|
func templater(cmd *cobra.Command, data *tplData) {
|
|
|
|
// outHandler takes a slice of byte slices, so append the Template
|
|
|
|
// execution output to a [][]byte before sending it to outHandler.
|
|
|
|
out := [][]byte{generateTemplate(cmd.Name(), data)}
|
|
|
|
|
|
|
|
imps := combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[cmd.Name()])
|
2016-03-23 04:03:35 +01:00
|
|
|
imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Columns)
|
|
|
|
|
2016-03-21 06:15:14 +01:00
|
|
|
err := outHandler(cmdData.OutFolder, out, data, &imps, false)
|
2016-03-18 16:27:55 +01:00
|
|
|
if err != nil {
|
|
|
|
errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err))
|
2016-02-24 06:40:07 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-03-02 06:37:14 +01:00
|
|
|
// testHarnessStdout is where generated output is written when no output
// folder is configured. It is a package variable, presumably so tests can
// substitute their own writer.
var testHarnessStdout io.Writer = os.Stdout

// testHarnessFileOpen creates (or truncates) filename and returns it for
// writing generated output. Like testHarnessStdout it is a variable,
// presumably so tests can substitute it.
//
// On failure it returns a literal nil io.WriteCloser. Returning the nil
// *os.File from os.Create directly would produce a non-nil interface that
// holds a nil pointer, which a caller checking w != nil would mistake for a
// usable file.
var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) {
	file, err := os.Create(filename)
	if err != nil {
		return nil, err
	}
	return file, nil
}
|
|
|
|
|
2016-02-24 09:53:34 +01:00
|
|
|
// outHandler loops over the slice of byte slices, outputting them to either
|
|
|
|
// the OutFile if it is specified with a flag, or to Stdout if no flag is specified.
|
2016-03-21 06:15:14 +01:00
|
|
|
func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports, testTemplate bool) error {
|
2016-03-02 06:37:14 +01:00
|
|
|
out := testHarnessStdout
|
2016-02-24 06:40:07 +01:00
|
|
|
|
2016-03-03 20:30:48 +01:00
|
|
|
var path string
|
2016-03-02 06:37:14 +01:00
|
|
|
if len(outFolder) != 0 {
|
2016-03-21 06:15:14 +01:00
|
|
|
if testTemplate {
|
|
|
|
path = outFolder + "/" + data.Table + "_test.go"
|
|
|
|
} else {
|
|
|
|
path = outFolder + "/" + data.Table + ".go"
|
|
|
|
}
|
|
|
|
|
2016-03-02 06:37:14 +01:00
|
|
|
outFile, err := testHarnessFileOpen(path)
|
2016-03-01 15:20:13 +01:00
|
|
|
if err != nil {
|
|
|
|
errorQuit(fmt.Errorf("Unable to create output file %s: %s", path, err))
|
2016-02-24 06:40:07 +01:00
|
|
|
}
|
2016-03-02 06:37:14 +01:00
|
|
|
defer outFile.Close()
|
|
|
|
out = outFile
|
|
|
|
}
|
2016-03-01 15:20:13 +01:00
|
|
|
|
2016-03-03 20:30:48 +01:00
|
|
|
if _, err := fmt.Fprintf(out, "package %s\n\n", cmdData.PkgName); err != nil {
|
|
|
|
errorQuit(fmt.Errorf("Unable to write package name %s to file: %s", cmdData.PkgName, path))
|
|
|
|
}
|
|
|
|
|
2016-03-02 16:18:26 +01:00
|
|
|
impStr := buildImportString(imps)
|
|
|
|
if len(impStr) > 0 {
|
|
|
|
if _, err := fmt.Fprintf(out, "%s\n", impStr); err != nil {
|
|
|
|
errorQuit(fmt.Errorf("Unable to write imports to file handle: %v", err))
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2016-03-02 06:37:14 +01:00
|
|
|
for _, templateOutput := range output {
|
|
|
|
if _, err := fmt.Fprintf(out, "%s\n", templateOutput); err != nil {
|
|
|
|
errorQuit(fmt.Errorf("Unable to write template output to file handle: %v", err))
|
2016-02-24 06:40:07 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
return nil
|
|
|
|
}
|
2016-03-02 07:34:57 +01:00
|
|
|
|
|
|
|
// combineStringSlices returns a freshly allocated slice containing the
// elements of a followed by the elements of b.
//
// The result never shares a backing array with either input, and it is
// non-nil (though possibly empty) even when both inputs are empty or nil.
func combineStringSlices(a, b []string) []string {
	combined := make([]string, 0, len(a)+len(b))
	combined = append(combined, a...)
	return append(combined, b...)
}
|
|
|
|
|
|
|
|
// removeDuplicates drops repeated strings from dedup and returns the shortened
// slice.
//
// The work is done in place on dedup's backing array, so the caller's slice
// contents are modified. Order is not preserved: when a duplicate is found,
// the last element still in the live region is moved into its slot and the
// live region shrinks by one. The nested scan performs a quadratic number of
// comparisons. A nil input returns nil.
func removeDuplicates(dedup []string) []string {
	end := len(dedup)
	for keep := 0; keep < end-1; keep++ {
		probe := keep + 1
		for probe < end {
			if dedup[probe] != dedup[keep] {
				probe++
				continue
			}
			// Overwrite the duplicate with the last live element and shrink
			// the live region. probe is not advanced, so the element that was
			// just moved in is compared on the next pass.
			end--
			dedup[probe] = dedup[end]
		}
	}
	return dedup[:end]
}
|