diff --git a/bdb/drivers/mock.go b/bdb/drivers/mock.go index 7d14f2a..810eb48 100644 --- a/bdb/drivers/mock.go +++ b/bdb/drivers/mock.go @@ -58,6 +58,10 @@ func (m *MockDriver) Columns(schema, tableName string) ([]bdb.Column, error) { }[tableName], nil } +func (m *MockDriver) UniqueKeyInfo(schema, tableName string) ([]bdb.UniqueKey, error) { + return []bdb.UniqueKey{}, nil +} + // ForeignKeyInfo returns a list of mock foreignkeys func (m *MockDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { return map[string][]bdb.ForeignKey{ diff --git a/bdb/drivers/mssql.go b/bdb/drivers/mssql.go index e59330b..70b9ed9 100644 --- a/bdb/drivers/mssql.go +++ b/bdb/drivers/mssql.go @@ -7,8 +7,8 @@ import ( "strings" _ "github.com/denisenkom/go-mssqldb" - "github.com/pkg/errors" "github.com/lbryio/sqlboiler/bdb" + "github.com/pkg/errors" ) // MSSQLDriver holds the database connection string and a handle @@ -241,6 +241,10 @@ func (m *MSSQLDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, return pkey, nil } +func (m *MSSQLDriver) UniqueKeyInfo(schema, tableName string) ([]bdb.UniqueKey, error) { + return []bdb.UniqueKey{}, errors.New("not implemented") +} + // ForeignKeyInfo retrieves the foreign keys for a given table name. func (m *MSSQLDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { var fkeys []bdb.ForeignKey diff --git a/bdb/drivers/mysql.go b/bdb/drivers/mysql.go index bb73574..6a20ba0 100644 --- a/bdb/drivers/mysql.go +++ b/bdb/drivers/mysql.go @@ -7,8 +7,8 @@ import ( "strings" "github.com/go-sql-driver/mysql" - "github.com/pkg/errors" "github.com/lbryio/sqlboiler/bdb" + "github.com/pkg/errors" ) // TinyintAsBool is a global that is set from main.go if a user specifies @@ -232,6 +232,46 @@ func (m *MySQLDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryKey, return pkey, nil } +// UniqueKeyInfo retrieves the unique keys for a given table name. +func (m *MySQLDriver) UniqueKeyInfo(schema, tableName string) ([]bdb.UniqueKey, error) { + var ukeys []bdb.UniqueKey + + query := ` + select tc.table_name, tc.constraint_name, GROUP_CONCAT(kcu.column_name) + from information_schema.table_constraints tc + left join information_schema.key_column_usage kcu on tc.constraint_name = kcu.constraint_name and tc.table_name = kcu.table_name and tc.table_schema = kcu.table_schema + where tc.table_schema = ? and tc.table_name = ? and tc.constraint_type = "UNIQUE" + group by tc.table_name, tc.constraint_name + ` + + var rows *sql.Rows + var err error + if rows, err = m.dbConn.Query(query, schema, tableName); err != nil { + return nil, err + } + + for rows.Next() { + var ukey bdb.UniqueKey + var columns string + + //ukey.Table = tableName + err = rows.Scan(&ukey.Table, &ukey.Name, &columns) + if err != nil { + return nil, err + } + + ukey.Columns = strings.Split(columns, ",") + + ukeys = append(ukeys, ukey) + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return ukeys, nil +} + // ForeignKeyInfo retrieves the foreign keys for a given table name. func (m *MySQLDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { var fkeys []bdb.ForeignKey diff --git a/bdb/drivers/postgres.go b/bdb/drivers/postgres.go index ad73a2e..4b6a9f4 100644 --- a/bdb/drivers/postgres.go +++ b/bdb/drivers/postgres.go @@ -8,10 +8,10 @@ import ( // Side-effect import sql driver - _ "github.com/lib/pq" - "github.com/pkg/errors" "github.com/lbryio/sqlboiler/bdb" "github.com/lbryio/sqlboiler/strmangle" + _ "github.com/lib/pq" + "github.com/pkg/errors" ) // PostgresDriver holds the database connection string and a handle @@ -266,6 +266,10 @@ func (p *PostgresDriver) PrimaryKeyInfo(schema, tableName string) (*bdb.PrimaryK return pkey, nil } +func (p *PostgresDriver) UniqueKeyInfo(schema, tableName string) ([]bdb.UniqueKey, error) { + return []bdb.UniqueKey{}, errors.New("not implemented") +} + // ForeignKeyInfo retrieves the foreign keys for a given table name. func (p *PostgresDriver) ForeignKeyInfo(schema, tableName string) ([]bdb.ForeignKey, error) { var fkeys []bdb.ForeignKey diff --git a/bdb/interface.go b/bdb/interface.go index dfb33bc..1fe179a 100644 --- a/bdb/interface.go +++ b/bdb/interface.go @@ -9,6 +9,7 @@ type Interface interface { TableNames(schema string, whitelist, blacklist []string) ([]string, error) Columns(schema, tableName string) ([]Column, error) PrimaryKeyInfo(schema, tableName string) (*PrimaryKey, error) + UniqueKeyInfo(schema, tableName string) ([]UniqueKey, error) ForeignKeyInfo(schema, tableName string) ([]ForeignKey, error) // TranslateColumnType takes a Database column type and returns a go column type. @@ -63,6 +64,10 @@ func Tables(db Interface, schema string, whitelist, blacklist []string) ([]Table return nil, errors.Wrapf(err, "unable to fetch table pkey info (%s)", name) } + if t.UKeys, err = db.UniqueKeyInfo(schema, name); err != nil { + return nil, errors.Wrapf(err, "unable to fetch table ukey info (%s)", name) + } + if t.FKeys, err = db.ForeignKeyInfo(schema, name); err != nil { return nil, errors.Wrapf(err, "unable to fetch table fkey info (%s)", name) } diff --git a/bdb/keys.go b/bdb/keys.go index 909ada2..8007843 100644 --- a/bdb/keys.go +++ b/bdb/keys.go @@ -8,6 +8,13 @@ type PrimaryKey struct { Columns []string } +// UniqueKey represents a unique key constraint in a database +type UniqueKey struct { + Table string + Name string + Columns []string +} + // ForeignKey represents a foreign key constraint in a database type ForeignKey struct { Table string diff --git a/bdb/table.go b/bdb/table.go index 62dcd32..d5cbb50 100644 --- a/bdb/table.go +++ b/bdb/table.go @@ -11,6 +11,7 @@ type Table struct { Columns []Column PKey *PrimaryKey + UKeys []UniqueKey FKeys []ForeignKey IsJoinTable bool diff --git a/boilingcore/imports.go b/boilingcore/imports.go index fc148c7..62b15bb 100644 --- a/boilingcore/imports.go +++ b/boilingcore/imports.go @@ -182,12 +182,13 @@ func newImporter() importer { "boil_queries": imports{ standard: importList{ `"database/sql"`, - `"strings"`, + `"fmt"`, }, thirdParty: importList{ `"github.com/lbryio/sqlboiler/boil"`, `"github.com/lbryio/sqlboiler/queries"`, `"github.com/lbryio/sqlboiler/queries/qm"`, + `"github.com/lbryio/sqlboiler/strmangle"`, `"github.com/pkg/errors"`, }, }, diff --git a/boilingcore/templates.go b/boilingcore/templates.go index 62c0ae5..86c78ce 100644 --- a/boilingcore/templates.go +++ b/boilingcore/templates.go @@ -8,10 +8,10 @@ import ( "strings" "text/template" - "github.com/pkg/errors" "github.com/lbryio/sqlboiler/bdb" "github.com/lbryio/sqlboiler/queries" "github.com/lbryio/sqlboiler/strmangle" + "github.com/pkg/errors" ) // templateData for sqlboiler templates diff --git a/templates/23_merge.tpl b/templates/23_merge.tpl index c7e27f5..35ca404 100644 --- a/templates/23_merge.tpl +++ b/templates/23_merge.tpl @@ -33,16 +33,33 @@ func Merge{{$tableNamePlural}}(exec boil.Executor, primaryID uint64, secondaryID return errors.New("Secondary {{$tableNameSingular}} not found") } - relatedFields := map[string]string{ + foreignKeys := []foreignKey{ {{- range .Tables -}} {{- range .FKeys -}} {{- if eq $dot.Table.Name .ForeignTable }} - "{{.Table }}": "{{ .Column}}", + {foreignTable: "{{.Table}}", foreignColumn: "{{.Column}}"}, {{- end -}} {{- end -}} {{- end }} } - err = mergeModels(tx, primaryID, secondaryID, relatedFields) + + conflictingKeys := []conflictingUniqueKey{ + {{- range .Tables -}} + {{- $table := . -}} + {{- range .FKeys -}} + {{- $fk := . -}} + {{- if eq $dot.Table.Name .ForeignTable -}} + {{- range $table.UKeys -}} + {{- if setInclude $fk.Column .Columns }} + {table: "{{$fk.Table}}", objectIdColumn: "{{$fk.Column}}", columns: []string{`{{ .Columns | join "`,`" }}`}}, + {{- end -}} + {{- end -}} + {{- end -}} + {{- end -}} + {{- end }} + } + + err = mergeModels(tx, primaryID, secondaryID, foreignKeys, conflictingKeys) if err != nil { tx.Rollback() return err diff --git a/templates/singleton/boil_queries.tpl b/templates/singleton/boil_queries.tpl index d129e43..8ff4edd 100644 --- a/templates/singleton/boil_queries.tpl +++ b/templates/singleton/boil_queries.tpl @@ -20,39 +20,97 @@ func NewQuery(exec boil.Executor, mods ...qm.QueryMod) *queries.Query { return q } -func mergeModels(tx *sql.Tx, primaryID uint64, secondaryID uint64, relatedFields map[string]string) error { - if len(relatedFields) < 1 { - return nil - } - - for table, column := range relatedFields { - // TODO: use NewQuery here, not plain sql - query := "UPDATE " + table + " SET " + column + " = ? WHERE " + column + " = ?" - _, err := tx.Exec(query, primaryID, secondaryID) - if err != nil { - return errors.WithStack(err) - } - } - return checkMerge(tx, relatedFields) -} - -func checkMerge(tx *sql.Tx, fields map[string]string) error { - columns := []interface{}{} - seenColumns := map[string]bool{} - placeholders := []string{} - for _, column := range fields { - if _, ok := seenColumns[column]; !ok { - columns = append(columns, column) - seenColumns[column] = true - placeholders = append(placeholders, "?") +func mergeModels(tx *sql.Tx, primaryID uint64, secondaryID uint64, foreignKeys []foreignKey, conflictingKeys []conflictingUniqueKey) error { + if len(foreignKeys) < 1 { + return nil + } + var err error + for _, conflict := range conflictingKeys { + err = deleteConflictsBeforeMerge(tx, conflict, primaryID, secondaryID) + if err != nil { + return errors.WithStack(err) } } - placeholder := strings.Join(placeholders, ", ") + for _, fk := range foreignKeys { + // TODO: use NewQuery here, not plain sql + query := fmt.Sprintf( + "UPDATE %s SET %s = %s WHERE %s = %s", + fk.foreignTable, fk.foreignColumn, strmangle.Placeholders(dialect.IndexPlaceholders, 1, 1, 1), + fk.foreignColumn, strmangle.Placeholders(dialect.IndexPlaceholders, 1, 2, 1), + ) + _, err = tx.Exec(query, primaryID, secondaryID) + if err != nil { + return errors.WithStack(err) + } + } + return checkMerge(tx, foreignKeys) +} - q := `SELECT table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA=DATABASE() AND column_name IN (` + placeholder + `)` - rows, err := tx.Query(q, columns...) +func deleteConflictsBeforeMerge(tx *sql.Tx, conflict conflictingUniqueKey, primaryID uint64, secondaryID uint64) error { + conflictingColumns := strmangle.SetComplement(conflict.columns, []string{conflict.objectIdColumn}) + + if len(conflictingColumns) < 1 { + return nil + } else if len(conflictingColumns) > 1 { + return errors.New("this doesnt work for unique keys with more than two columns (yet)") + } + + query := fmt.Sprintf( + "SELECT %s FROM %s WHERE %s IN (%s) GROUP BY %s HAVING count(distinct %s) > 1", + conflictingColumns[0], conflict.table, conflict.objectIdColumn, + strmangle.Placeholders(dialect.IndexPlaceholders, 2, 1, 1), + conflictingColumns[0], conflict.objectIdColumn, + ) + + rows, err := tx.Query(query, primaryID, secondaryID) + defer rows.Close() + if err != nil { + return errors.WithStack(err) + } + + args := []interface{}{secondaryID} + for rows.Next() { + var value string + err = rows.Scan(&value) + if err != nil { + return errors.WithStack(err) + } + args = append(args, value) + } + + query = fmt.Sprintf( + "DELETE FROM %s WHERE %s = %s AND %s IN (%s)", + conflict.table, conflict.objectIdColumn, strmangle.Placeholders(dialect.IndexPlaceholders, 1, 1, 1), + conflictingColumns[0], strmangle.Placeholders(dialect.IndexPlaceholders, len(args)-1, 2, 1), + ) + + _, err = tx.Exec(query, args...) + if err != nil { + return errors.WithStack(err) + } + return nil +} + +func checkMerge(tx *sql.Tx, foreignKeys []foreignKey) error { + uniqueColumns := []interface{}{} + uniqueColumnNames := map[string]bool{} + handledTablesColumns := map[string]bool{} + + for _, fk := range foreignKeys { + handledTablesColumns[fk.foreignTable+"."+fk.foreignColumn] = true + if _, ok := uniqueColumnNames[fk.foreignColumn]; !ok { + uniqueColumns = append(uniqueColumns, fk.foreignColumn) + uniqueColumnNames[fk.foreignColumn] = true + } + } + + q := fmt.Sprintf( + `SELECT table_name, column_name FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA=DATABASE() AND column_name IN (%s)`, + strmangle.Placeholders(dialect.IndexPlaceholders, len(uniqueColumns), 1, 1), + ) + rows, err := tx.Query(q, uniqueColumns...) defer rows.Close() if err != nil { return errors.WithStack(err) @@ -66,7 +124,7 @@ func checkMerge(tx *sql.Tx, fields map[string]string) error { return errors.WithStack(err) } - if _, exists := fields[tableName]; !exists { + if _, exists := handledTablesColumns[tableName+"."+columnName]; !exists { return errors.New("Missing merge for " + tableName + "." + columnName) } } diff --git a/templates/singleton/boil_types.tpl b/templates/singleton/boil_types.tpl index 85e05d9..aa61bf9 100644 --- a/templates/singleton/boil_types.tpl +++ b/templates/singleton/boil_types.tpl @@ -6,6 +6,22 @@ type Nullable interface { IsZero() bool } +// foreignKey connects two tables. When merging records, foreign keys from secondary record must +// be reassigned to primary record. +type foreignKey struct { + foreignTable string + foreignColumn string +} + +// conflictingUniqueKey records a merge conflict. If two rows exist with the same value in the +// conflicting column for two records being merged, one row must be deleted. +type conflictingUniqueKey struct { + table string + objectIdColumn string + columns []string +} + + // ErrSyncFail occurs during insert when the record could not be retrieved in // order to populate default value information. This usually happens when LastInsertId // fails or there was a primary key configuration that was not resolvable.