diff --git a/README.md b/README.md index f3cf24c..579d550 100644 --- a/README.md +++ b/README.md @@ -22,6 +22,7 @@ lifecycle. - Easy workflow (models can always be regenerated, full auto-complete) - Strongly typed querying (usually no converting or binding to pointers) - Hooks (Before/After Create/Update) +- Automatic CreatedAt/UpdatedAt - Relationships/Associations - Eager loading - Transactions @@ -29,17 +30,20 @@ lifecycle. - Compatibility tests (Run against your own DB schema) - Debug logging -#### Missing Features - -- Automatic CreatedAt UpdatedAt (use Hooks instead) -- Nested eager loading - #### Supported Databases - PostgreSQL Note: Seeking contributors for other database engines. +#### Automatic CreatedAt/UpdatedAt + +If your generated SQLBoiler models package can find columns with the +names `created_at` or `updated_at` it will automatically set them +to `time.Now()` in your database, and update your object appropriately. + +Note: You can set the timezone for this feature by calling `boil.SetLocation()` + #### Example Queries ```go diff --git a/boil/db.go b/boil/db.go index e74c01e..84bc7d3 100644 --- a/boil/db.go +++ b/boil/db.go @@ -1,14 +1,6 @@ package boil -import ( - "database/sql" - "os" -) - -var ( - // currentDB is a global database handle for the package - currentDB Executor -) +import "database/sql" // Executor can perform SQL queries. type Executor interface { @@ -30,15 +22,6 @@ type Beginner interface { Begin() (*sql.Tx, error) } -// DebugMode is a flag controlling whether generated sql statements and -// debug information is outputted to the DebugWriter handle -// -// NOTE: This should be disabled in production to avoid leaking sensitive data -var DebugMode = false - -// DebugWriter is where the debug output will be sent if DebugMode is true -var DebugWriter = os.Stdout - // Begin a transaction func Begin() (Transactor, error) { creator, ok := currentDB.(Beginner) @@ -48,13 +31,3 @@ func Begin() (Transactor, error) { return creator.Begin() } - -// SetDB initializes the database handle for all template db interactions -func SetDB(db Executor) { - currentDB = db -} - -// GetDB retrieves the global state database handle -func GetDB() Executor { - return currentDB -} diff --git a/boil/global.go b/boil/global.go new file mode 100644 index 0000000..7eae7e6 --- /dev/null +++ b/boil/global.go @@ -0,0 +1,50 @@ +package boil + +import ( + "os" + "time" +) + +var ( + // currentDB is a global database handle for the package + currentDB Executor + // timestampLocation is the timezone used for the + // automated setting of created_at/updated_at columns + timestampLocation *time.Location +) + +// DebugMode is a flag controlling whether generated sql statements and +// debug information is outputted to the DebugWriter handle +// +// NOTE: This should be disabled in production to avoid leaking sensitive data +var DebugMode = false + +// DebugWriter is where the debug output will be sent if DebugMode is true +var DebugWriter = os.Stdout + +// SetDB initializes the database handle for all template db interactions +func SetDB(db Executor) { + currentDB = db +} + +// GetDB retrieves the global state database handle +func GetDB() Executor { + return currentDB +} + +// SetLocation sets the global timestamp Location. +// This is the timezone used by the generated package for the +// automated setting of created_at and updated_at columns. +// If the package was generated with the --no-auto-timestamps flag +// then this function has no effect. +func SetLocation(loc *time.Location) { + timestampLocation = loc +} + +// GetLocation retrieves the global timestamp Location. +// This is the timezone used by the generated package for the +// automated setting of created_at and updated_at columns +// if the package was not generated with the --no-auto-timestamps flag. +func GetLocation() *time.Location { + return timestampLocation +} diff --git a/config.go b/config.go index 0675399..349c90c 100644 --- a/config.go +++ b/config.go @@ -2,12 +2,13 @@ package main // Config for the running of the commands type Config struct { - DriverName string `toml:"driver_name"` - PkgName string `toml:"pkg_name"` - OutFolder string `toml:"out_folder"` - BaseDir string `toml:"base_dir"` - ExcludeTables []string `toml:"exclude"` - NoHooks bool `toml:"no_hooks"` + DriverName string `toml:"driver_name"` + PkgName string `toml:"pkg_name"` + OutFolder string `toml:"out_folder"` + BaseDir string `toml:"base_dir"` + ExcludeTables []string `toml:"exclude"` + NoHooks bool `toml:"no_hooks"` + NoAutoTimestamps bool `toml:"no_auto_timestamps"` Postgres PostgresConfig `toml:"postgres"` } diff --git a/imports.go b/imports.go index fa4424e..8ace24f 100644 --- a/imports.go +++ b/imports.go @@ -146,6 +146,7 @@ var defaultTemplateImports = imports{ `"fmt"`, `"strings"`, `"database/sql"`, + `"time"`, }, thirdParty: importList{ `"github.com/pkg/errors"`, diff --git a/main.go b/main.go index 3a8a2ec..e365a01 100644 --- a/main.go +++ b/main.go @@ -66,6 +66,7 @@ func main() { rootCmd.PersistentFlags().StringSliceP("exclude", "x", nil, "Tables to be excluded from the generated package") rootCmd.PersistentFlags().BoolP("debug", "d", false, "Debug mode prints stack traces on error") rootCmd.PersistentFlags().BoolP("no-hooks", "", false, "Disable hooks feature for your models") + rootCmd.PersistentFlags().BoolP("no-auto-timestamps", "", false, "Disable automatic timestamps for created_at/updated_at") viper.SetDefault("postgres.sslmode", "require") viper.SetDefault("postgres.port", "5432") @@ -102,10 +103,11 @@ func preRun(cmd *cobra.Command, args []string) error { driverName := args[0] cmdConfig = &Config{ - DriverName: driverName, - OutFolder: viper.GetString("output"), - PkgName: viper.GetString("pkgname"), - NoHooks: viper.GetBool("no-hooks"), + DriverName: driverName, + OutFolder: viper.GetString("output"), + PkgName: viper.GetString("pkgname"), + NoHooks: viper.GetBool("no-hooks"), + NoAutoTimestamps: viper.GetBool("no-auto-timestamps"), } // BUG: https://github.com/spf13/viper/issues/200 diff --git a/sqlboiler.go b/sqlboiler.go index 9d0be57..4d17322 100644 --- a/sqlboiler.go +++ b/sqlboiler.go @@ -77,11 +77,12 @@ func New(config *Config) (*State, error) { // state given. func (s *State) Run(includeTests bool) error { singletonData := &templateData{ - Tables: s.Tables, - DriverName: s.Config.DriverName, - UseLastInsertID: s.Driver.UseLastInsertID(), - PkgName: s.Config.PkgName, - NoHooks: s.Config.NoHooks, + Tables: s.Tables, + DriverName: s.Config.DriverName, + UseLastInsertID: s.Driver.UseLastInsertID(), + PkgName: s.Config.PkgName, + NoHooks: s.Config.NoHooks, + NoAutoTimestamps: s.Config.NoAutoTimestamps, StringFuncs: templateStringMappers, } @@ -106,12 +107,13 @@ func (s *State) Run(includeTests bool) error { } data := &templateData{ - Tables: s.Tables, - Table: table, - DriverName: s.Config.DriverName, - UseLastInsertID: s.Driver.UseLastInsertID(), - PkgName: s.Config.PkgName, - NoHooks: s.Config.NoHooks, + Tables: s.Tables, + Table: table, + DriverName: s.Config.DriverName, + UseLastInsertID: s.Driver.UseLastInsertID(), + PkgName: s.Config.PkgName, + NoHooks: s.Config.NoHooks, + NoAutoTimestamps: s.Config.NoAutoTimestamps, StringFuncs: templateStringMappers, } diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 3c04da5..ca78531 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -411,3 +411,17 @@ func StringSliceMatch(a []string, b []string) bool { return true } + +// ContainsAny returns true if any of the passed in strings are +// found in the passed in string slice +func ContainsAny(a []string, finds ...string) bool { + for _, s := range a { + for _, find := range finds { + if s == find { + return true + } + } + } + + return false +} diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index 862b221..03f777a 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -378,3 +378,32 @@ func TestStringSliceMatch(t *testing.T) { } } } + +func TestContainsAny(t *testing.T) { + t.Parallel() + + a := []string{"hello", "friend"} + if ContainsAny([]string{}, "x") { + t.Errorf("Should not contain x") + } + + if ContainsAny(a, "x") { + t.Errorf("Should not contain x") + } + + if !ContainsAny(a, "hello") { + t.Errorf("Should contain hello") + } + + if !ContainsAny(a, "friend") { + t.Errorf("Should contain friend") + } + + if !ContainsAny(a, "hello", "friend") { + t.Errorf("Should contain hello and friend") + } + + if ContainsAny(a) { + t.Errorf("Should not return true") + } +} diff --git a/templates.go b/templates.go index cc4334c..a2b7db3 100644 --- a/templates.go +++ b/templates.go @@ -13,12 +13,13 @@ import ( // templateData for sqlboiler templates type templateData struct { - Tables []bdb.Table - Table bdb.Table - DriverName string - UseLastInsertID bool - PkgName string - NoHooks bool + Tables []bdb.Table + Table bdb.Table + DriverName string + UseLastInsertID bool + PkgName string + NoHooks bool + NoAutoTimestamps bool StringFuncs map[string]func(string) string } @@ -127,6 +128,7 @@ var templateFunctions = template.FuncMap{ "joinSlices": strmangle.JoinSlices, "stringMap": strmangle.StringMap, "prefixStringSlice": strmangle.PrefixStringSlice, + "containsAny": strmangle.ContainsAny, // String Map ops "makeStringMap": strmangle.MakeStringMap, diff --git a/templates/01_types.tpl b/templates/01_types.tpl index 1a289c0..eb62947 100644 --- a/templates/01_types.tpl +++ b/templates/01_types.tpl @@ -22,3 +22,6 @@ type ( *boil.Query } ) + +// Force time package dependency for automated UpdatedAt/CreatedAt. +var _ = time.Second diff --git a/templates/10_insert.tpl b/templates/10_insert.tpl index ce05456..a486b26 100644 --- a/templates/10_insert.tpl +++ b/templates/10_insert.tpl @@ -32,6 +32,8 @@ func (o *{{$tableNameSingular}}) Insert(exec boil.Executor, whitelist ... string } var err error + {{- template "timestamp_insert_helper" . }} + {{if eq .NoHooks false -}} if err := o.doBeforeInsertHooks(); err != nil { return err diff --git a/templates/11_update.tpl b/templates/11_update.tpl index 970e1e1..4152335 100644 --- a/templates/11_update.tpl +++ b/templates/11_update.tpl @@ -35,6 +35,8 @@ func (o *{{$tableNameSingular}}) UpdateP(exec boil.Executor, whitelist ... strin // Update does not automatically update the record in case of default values. Use .Reload() // to refresh the records. func (o *{{$tableNameSingular}}) Update(exec boil.Executor, whitelist ... string) error { + {{- template "timestamp_update_helper" . -}} + {{if eq .NoHooks false -}} if err := o.doBeforeUpdateHooks(); err != nil { return err diff --git a/templates/17_auto_timestamps.tpl b/templates/17_auto_timestamps.tpl new file mode 100644 index 0000000..4adae76 --- /dev/null +++ b/templates/17_auto_timestamps.tpl @@ -0,0 +1,56 @@ +{{- define "timestamp_insert_helper" -}} + {{- if eq .NoAutoTimestamps false -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "created_at" "updated_at"}} + loc := boil.GetLocation() + currTime := time.Time{} + if loc != nil { + currTime = time.Now().In(boil.GetLocation()) + } else { + currTime = time.Now() + } + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "created_at" -}} + {{- if $col.Nullable}} + o.CreatedAt.Time = currTime + o.CreatedAt.Valid = true + {{- else}} + o.CreatedAt = currTime + {{- end -}} + {{- end -}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + {{- else}} + o.UpdatedAt = currTime + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} +{{- end -}} +{{- define "timestamp_update_helper" -}} + {{- if eq .NoAutoTimestamps false -}} + {{- $colNames := .Table.Columns | columnNames -}} + {{if containsAny $colNames "updated_at"}} + loc := boil.GetLocation() + currTime := time.Time{} + if loc != nil { + currTime = time.Now().In(boil.GetLocation()) + } else { + currTime = time.Now() + } + {{range $ind, $col := .Table.Columns}} + {{- if eq $col.Name "updated_at" -}} + {{- if $col.Nullable}} + o.UpdatedAt.Time = currTime + o.UpdatedAt.Valid = true + {{- else}} + o.UpdatedAt = currTime + {{- end -}} + {{- end -}} + {{end}} + {{end}} + {{- end}} +{{end -}}