diff --git a/strmangle/strmangle.go b/strmangle/strmangle.go index 567ef4b..e5f067e 100644 --- a/strmangle/strmangle.go +++ b/strmangle/strmangle.go @@ -318,3 +318,27 @@ func JoinSlices(sep string, a, b []string) []string { return ret } + +// StringSliceMatch returns true if the length of both +// slices is the same, and the elements of both slices are the same. +// The elements can be in any order. +func StringSliceMatch(a []string, b []string) bool { + if len(a) != len(b) { + return false + } + + for _, aval := range a { + found := false + for _, bval := range b { + if bval == aval { + found = true + break + } + } + if !found { + return false + } + } + + return true +} diff --git a/strmangle/strmangle_test.go b/strmangle/strmangle_test.go index 1bb6333..bfdbfb8 100644 --- a/strmangle/strmangle_test.go +++ b/strmangle/strmangle_test.go @@ -297,3 +297,60 @@ func TestJoinSlicesFail(t *testing.T) { JoinSlices("", nil, []string{"hello"}) } + +func TestStringSliceMatch(t *testing.T) { + t.Parallel() + + tests := []struct { + a []string + b []string + expect bool + }{ + { + a: []string{}, + b: []string{}, + expect: true, + }, + { + a: []string{"a"}, + b: []string{}, + expect: false, + }, + { + a: []string{"a"}, + b: []string{"a"}, + expect: true, + }, + { + a: []string{}, + b: []string{"b"}, + expect: false, + }, + { + a: []string{"c", "d"}, + b: []string{"b", "d"}, + expect: false, + }, + { + a: []string{"b", "d"}, + b: []string{"c", "d"}, + expect: false, + }, + { + a: []string{"a", "b", "c"}, + b: []string{"c", "b", "a"}, + expect: true, + }, + { + a: []string{"a", "b", "c"}, + b: []string{"a", "b", "c"}, + expect: true, + }, + } + + for i, test := range tests { + if StringSliceMatch(test.a, test.b) != test.expect { + t.Errorf("%d) Expected match to return %v, but got %v", i, test.expect, !test.expect) + } + } +} diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index 7df71a2..ab01981 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -29,12 +29,12 @@ func generateUpsertQuery(tableName string, updateOnConflict bool, ret, update, c strmangle.Placeholders(len(whitelist), 1, 1), ) - if !updateOnConflict { + if !updateOnConflict || len(update) == 0 { buf.WriteString("DO NOTHING") } else { buf.WriteByte('(') buf.WriteString(strings.Join(conflict, ", ")) - buf.WriteString(") DO UPDATE SET") + buf.WriteString(") DO UPDATE SET ") for i, v := range update { if i != 0 { diff --git a/templates_test/helpers.tpl b/templates_test/helpers.tpl index 1b1b9a1..c20dbd0 100644 --- a/templates_test/helpers.tpl +++ b/templates_test/helpers.tpl @@ -49,11 +49,13 @@ func test{{$tableNamePlural}}SliceInPrimaryKeyArgs(t *testing.T) { t.Errorf("Expected args to be len %d, but got %d", len({{$varNameSingular}}PrimaryKeyColumns) * 3, len(args)) } - for i := 0; i < len({{$varNameSingular}}PrimaryKeyColumns) * 3; i++ { + argC := 0 + for i := 0; i < 3; i++ { {{range $key, $value := .Table.PKey.Columns}} - if o[i].{{titleCase $value}} != args[i] { + if o[i].{{titleCase $value}} != args[argC] { t.Errorf("Expected args[%d] to be value of o.{{titleCase $value}}, but got %#v", i, args[i]) } + argC++ {{- end}} } } diff --git a/templates_test/update.tpl b/templates_test/update.tpl index 6e2d645..5cd15be 100644 --- a/templates_test/update.tpl +++ b/templates_test/update.tpl @@ -31,8 +31,18 @@ func test{{$tableNamePlural}}Update(t *testing.T) { t.Errorf("Unable to randomize {{$tableNameSingular}} struct: %s", err) } - if err = {{$varNameSingular}}.Update(tx); err != nil { - t.Error(err) + // If table only contains primary key columns, we need to pass + // them into a whitelist to get a valid test result, + // otherwise the Update method will error because it will not be able to + // generate a whitelist (due to it excluding primary key columns). + if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { + if err = {{$varNameSingular}}.Update(tx, {{$varNameSingular}}PrimaryKeyColumns...); err != nil { + t.Error(err) + } + } else { + if err = {{$varNameSingular}}.Update(tx); err != nil { + t.Error(err) + } } } @@ -66,10 +76,15 @@ func test{{$tableNamePlural}}SliceUpdateAll(t *testing.T) { } // Remove Primary keys and unique columns from what we plan to update - fields := strmangle.SetComplement( - {{$varNameSingular}}Columns, - {{$varNameSingular}}PrimaryKeyColumns, - ) + var fields []string + if strmangle.StringSliceMatch({{$varNameSingular}}Columns, {{$varNameSingular}}PrimaryKeyColumns) { + fields = {{$varNameSingular}}Columns + } else { + fields = strmangle.SetComplement( + {{$varNameSingular}}Columns, + {{$varNameSingular}}PrimaryKeyColumns, + ) + } value := reflect.Indirect(reflect.ValueOf({{$varNameSingular}})) updateMap := M{} diff --git a/testdata/test_schema.sql b/testdata/test_schema.sql index 717aff6..3d63851 100644 --- a/testdata/test_schema.sql +++ b/testdata/test_schema.sql @@ -117,3 +117,21 @@ create table cat_toys ( toy_id int not null references toys (id), primary key (cat_id, toy_id) ); + +create table dog_toys ( + dog_id int not null, + toy_id int not null, + primary key (dog_id, toy_id) +); + +create table dragon_toys ( + dragon_id uuid, + toy_id uuid, + primary key (dragon_id, toy_id) +); + +create table spider_toys ( + spider_id uuid, + name character varying, + primary key (spider_id) +);