diff --git a/cmds/boil.go b/cmds/boil.go index 3a22f76..6deae25 100644 --- a/cmds/boil.go +++ b/cmds/boil.go @@ -22,10 +22,9 @@ func boilRun(cmd *cobra.Command, args []string) { // the main template initializes all of the testing assets testCommandNames := append([]string{"main"}, commandNames...) - for i := 0; i < len(cmdData.Columns); i++ { - data := tplData{ - Table: cmdData.Tables[i], - Columns: cmdData.Columns[i], + for _, table := range cmdData.Tables { + data := &tplData{ + Table: table, PkgName: cmdData.PkgName, } @@ -38,11 +37,11 @@ func boilRun(cmd *cobra.Command, args []string) { // Loop through and generate every command template (excluding skipTemplates) for _, command := range commandNames { imps = combineImports(imps, sqlBoilerCustomImports[command]) - imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Columns) - out = append(out, generateTemplate(command, &data)) + imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Table.Columns) + out = append(out, generateTemplate(command, data)) } - err := outHandler(cmdData.OutFolder, out, &data, &imps, false) + err := outHandler(cmdData.OutFolder, out, data, imps, false) if err != nil { errorQuit(err) } @@ -58,10 +57,10 @@ func boilRun(cmd *cobra.Command, args []string) { // Loop through and generate every command test template (excluding skipTemplates) for _, command := range testCommandNames { testImps = combineImports(testImps, sqlBoilerCustomTestImports[command]) - testOut = append(testOut, generateTestTemplate(command, &data)) + testOut = append(testOut, generateTestTemplate(command, data)) } - err = outHandler(cmdData.OutFolder, testOut, &data, &testImps, true) + err = outHandler(cmdData.OutFolder, testOut, data, testImps, true) if err != nil { errorQuit(err) } diff --git a/cmds/commands.go b/cmds/commands.go index f7b0180..d4a9e33 100644 --- a/cmds/commands.go +++ b/cmds/commands.go @@ -34,7 +34,7 @@ var sqlBoilerDefaultTestImports = imports{ } // sqlBoilerConditionalTypeImports imports are only included in the template output -// if the database requires one of the following special types. Check TranslateColumn +// if the database requires one of the following special types. Check TranslateColumnType // to see the type assignments. var sqlBoilerConditionalTypeImports = map[string]imports{ "null.Int": imports{ diff --git a/cmds/imports.go b/cmds/imports.go index afe50eb..fc3a6b4 100644 --- a/cmds/imports.go +++ b/cmds/imports.go @@ -9,8 +9,6 @@ import ( "github.com/pobri19/sqlboiler/dbdrivers" ) -type ImportSorter []string - func (i importList) Len() int { return len(i) } @@ -67,7 +65,7 @@ func combineConditionalTypeImports(a imports, b map[string]imports, columns []db return tmpImp } -func buildImportString(imps *imports) []byte { +func buildImportString(imps imports) []byte { stdlen, thirdlen := len(imps.standard), len(imps.thirdparty) if stdlen+thirdlen < 1 { return []byte{} diff --git a/cmds/shared.go b/cmds/shared.go index 4e40e19..98f6e02 100644 --- a/cmds/shared.go +++ b/cmds/shared.go @@ -17,17 +17,15 @@ type CobraRunFunc func(cmd *cobra.Command, args []string) // the database driver chosen by the driver flag at runtime, and a pointer to the // output file, if one is specified with a flag. type CmdData struct { - Tables []string - Columns [][]dbdrivers.Column + Tables []dbdrivers.Table PkgName string OutFolder string - Interface dbdrivers.Interface + Interface dbdrivers.Interface } // tplData is used to pass data to the template type tplData struct { - Table string - Columns []dbdrivers.Column + Table dbdrivers.Table PkgName string } @@ -41,14 +39,13 @@ func errorQuit(err error) { // It will generate the specific commands template and send it to outHandler for output. func defaultRun(cmd *cobra.Command, args []string) { // Generate the template for every table - for i := 0; i < len(cmdData.Columns); i++ { - data := tplData{ - Table: cmdData.Tables[i], - Columns: cmdData.Columns[i], + for _, t := range cmdData.Tables { + data := &tplData{ + Table: t, PkgName: cmdData.PkgName, } - templater(cmd, &data) + templater(cmd, data) } } @@ -61,9 +58,9 @@ func templater(cmd *cobra.Command, data *tplData) { out := [][]byte{generateTemplate(cmd.Name(), data)} imps := combineImports(sqlBoilerDefaultImports, sqlBoilerCustomImports[cmd.Name()]) - imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Columns) + imps = combineConditionalTypeImports(imps, sqlBoilerConditionalTypeImports, data.Table.Columns) - err := outHandler(cmdData.OutFolder, out, data, &imps, false) + err := outHandler(cmdData.OutFolder, out, data, imps, false) if err != nil { errorQuit(fmt.Errorf("Unable to generate the template for command %s: %s", cmd.Name(), err)) } @@ -77,15 +74,15 @@ var testHarnessFileOpen = func(filename string) (io.WriteCloser, error) { // outHandler loops over the slice of byte slices, outputting them to either // the OutFile if it is specified with a flag, or to Stdout if no flag is specified. -func outHandler(outFolder string, output [][]byte, data *tplData, imps *imports, testTemplate bool) error { +func outHandler(outFolder string, output [][]byte, data *tplData, imps imports, testTemplate bool) error { out := testHarnessStdout var path string if len(outFolder) != 0 { if testTemplate { - path = outFolder + "/" + data.Table + "_test.go" + path = outFolder + "/" + data.Table.Name + "_test.go" } else { - path = outFolder + "/" + data.Table + ".go" + path = outFolder + "/" + data.Table.Name + ".go" } outFile, err := testHarnessFileOpen(path) diff --git a/cmds/sqlboiler.go b/cmds/sqlboiler.go index 6b33cdb..517a9a0 100644 --- a/cmds/sqlboiler.go +++ b/cmds/sqlboiler.go @@ -79,9 +79,6 @@ func sqlBoilerPreRun(cmd *cobra.Command, args []string) { // Initialize the cmdData.Tables initTables() - // Initialize the cmdData.Columns - initColumns() - // Initialize the package name initPkgName() @@ -136,39 +133,23 @@ func initTables() { // Retrieve the list of tables tn := SQLBoiler.PersistentFlags().Lookup("table").Value.String() + var tableNames []string + if len(tn) != 0 { - cmdData.Tables = strings.Split(tn, ",") - for i, name := range cmdData.Tables { - cmdData.Tables[i] = strings.TrimSpace(name) + tableNames = strings.Split(tn, ",") + for i, name := range tableNames { + tableNames[i] = strings.TrimSpace(name) } } - // If no table names are provided attempt to process all tables in database + var err error + cmdData.Tables, err = cmdData.Interface.Tables(tableNames...) + if err != nil { + errorQuit(fmt.Errorf("Unable to get all table names: %s", err)) + } + if len(cmdData.Tables) == 0 { - // get all table names - var err error - cmdData.Tables, err = cmdData.Interface.AllTables() - if err != nil { - errorQuit(fmt.Errorf("Unable to get all table names: %s", err)) - } - - if len(cmdData.Tables) == 0 { - errorQuit(errors.New("No tables found in database, migrate some tables first")) - } - } -} - -// initColumns builds a description of each table (column name, column type) -// and assigns it to cmdData.Columns, the slice of dbdrivers.Column slices. -func initColumns() { - // loop over table Names and build Columns - for i := 0; i < len(cmdData.Tables); i++ { - tInfo, err := cmdData.Interface.Columns(cmdData.Tables[i]) - if err != nil { - errorQuit(fmt.Errorf("Unable to get the table info: %s", err)) - } - - cmdData.Columns = append(cmdData.Columns, tInfo) + errorQuit(errors.New("No tables found in database, migrate some tables first")) } } diff --git a/dbdrivers/interface.go b/dbdrivers/interface.go index 84e7f91..a6ef649 100644 --- a/dbdrivers/interface.go +++ b/dbdrivers/interface.go @@ -3,16 +3,13 @@ package dbdrivers // Interface for a database driver. Functionality required to support a specific // database type (eg, MySQL, Postgres etc.) type Interface interface { - // AllTables connects to the database and retrieves all "public" table names - AllTables() ([]string, error) + // Tables connects to the database and retrieves the table metadata for + // the given tables, or all tables if none are provided. + Tables(names ...string) ([]Table, error) - // Columns retrieves column information about the table. - Columns(tableName string) ([]Column, error) - - // TranslateColumn builds a Column out of a column metadata. - // Its main responsibility is to convert database types to Go types, for - // example "varchar" to "string". - TranslateColumn(Column) Column + // TranslateColumnType takes a Database column type and returns a go column + // type. + TranslateColumnType(Column) Column // Open the database connection Open() error @@ -30,7 +27,7 @@ type Table struct { } // Column holds information about a database column. -// Types are Go types, converted by TranslateColumn. +// Types are Go types, converted by TranslateColumnType. type Column struct { Name string Type string diff --git a/dbdrivers/postgres_driver.go b/dbdrivers/postgres_driver.go index 5e3d207..19a2c3f 100644 --- a/dbdrivers/postgres_driver.go +++ b/dbdrivers/postgres_driver.go @@ -29,9 +29,9 @@ func NewPostgresDriver(user, pass, dbname, host string, port int) *PostgresDrive } // Open opens the database connection using the connection string -func (d *PostgresDriver) Open() error { +func (p *PostgresDriver) Open() error { var err error - d.dbConn, err = sql.Open("postgres", d.connStr) + p.dbConn, err = sql.Open("postgres", p.connStr) if err != nil { return err } @@ -40,18 +40,44 @@ func (d *PostgresDriver) Open() error { } // Close closes the database connection -func (d *PostgresDriver) Close() { - d.dbConn.Close() +func (p *PostgresDriver) Close() { + p.dbConn.Close() } -// AllTables connects to the postgres database and +// Tables returns the table metadata for the given tables, or all tables if +// no tables are provided. +func (p *PostgresDriver) Tables(names ...string) ([]Table, error) { + var err error + if len(names) == 0 { + if names, err = p.tableNames(); err != nil { + return nil, err + } + } + + var tables []Table + for _, name := range names { + columns, err := p.columns(name) + if err != nil { + return nil, err + } + + tables = append(tables, Table{ + Name: name, + Columns: columns, + }) + } + + return tables, nil +} + +// tableNames connects to the postgres database and // retrieves all table names from the information_schema where the // table schema is public. It excludes common migration tool tables // such as gorp_migrations -func (d *PostgresDriver) AllTables() ([]string, error) { - var tableNames []string +func (p *PostgresDriver) tableNames() ([]string, error) { + var names []string - rows, err := d.dbConn.Query(`select table_name from + rows, err := p.dbConn.Query(`select table_name from information_schema.tables where table_schema='public' and table_name <> 'gorp_migrations'`) @@ -61,24 +87,24 @@ func (d *PostgresDriver) AllTables() ([]string, error) { defer rows.Close() for rows.Next() { - var tableName string - if err := rows.Scan(&tableName); err != nil { + var name string + if err := rows.Scan(&name); err != nil { return nil, err } - tableNames = append(tableNames, tableName) + names = append(names, name) } - return tableNames, nil + return names, nil } -// Columns takes a table name and attempts to retrieve the table information +// columns takes a table name and attempts to retrieve the table information // from the database information_schema.columns. It retrieves the column names -// and column types and returns those as a []Column after TranslateColumn() +// and column types and returns those as a []Column after TranslateColumnType() // converts the SQL types to Go types, for example: "varchar" to "string" -func (d *PostgresDriver) Columns(tableName string) ([]Column, error) { - var table []Column +func (p *PostgresDriver) columns(tableName string) ([]Column, error) { + var columns []Column - rows, err := d.dbConn.Query(` + rows, err := p.dbConn.Query(` SELECT c.column_name, c.data_type, c.is_nullable, CASE WHEN pk.column_name IS NOT NULL THEN 'PRIMARY KEY' ELSE '' END AS KeyType FROM information_schema.columns c @@ -104,22 +130,22 @@ func (d *PostgresDriver) Columns(tableName string) ([]Column, error) { if err := rows.Scan(&colName, &colType, &isNullable, &isPrimary); err != nil { return nil, err } - t := d.TranslateColumn(Column{ + column := p.TranslateColumnType(Column{ Name: colName, Type: colType, IsNullable: isNullable == "YES", IsPrimaryKey: isPrimary == "PRIMARY KEY", }) - table = append(table, t) + columns = append(columns, column) } - return table, nil + return columns, nil } -// TranslateColumn converts postgres database types to Go types, for example +// TranslateColumnType converts postgres database types to Go types, for example // "varchar" to "string" and "bigint" to "int64". It returns this parsed data // as a Column object. -func (d *PostgresDriver) TranslateColumn(c Column) Column { +func (p *PostgresDriver) TranslateColumnType(c Column) Column { if c.IsNullable { switch c.Type { case "bigint", "bigserial", "integer", "smallint", "smallserial", "serial":